1use crate::{Context, error::ModelError};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelInfo {
12 pub handle: String,
13 pub provider: String,
14 pub model: String,
15 pub context_window: u32,
16 pub input_cost_usd_per_million_tokens: Option<f64>,
17 pub output_cost_usd_per_million_tokens: Option<f64>,
18 pub supports_tool_use: bool,
19 pub supports_streaming: bool,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ModelOutput {
24 pub text: Option<String>,
25 pub tool_calls: Vec<ToolCall>,
26 pub usage: Usage,
27 pub stop_reason: StopReason,
28 #[serde(default, skip_serializing_if = "Option::is_none")]
32 pub reasoning: Option<String>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ToolCall {
37 pub id: String,
38 pub name: String,
39 pub args: serde_json::Value,
40}
41
42#[derive(Debug, Clone, Default, Serialize, Deserialize)]
43pub struct Usage {
44 pub input_tokens: u32,
45 pub output_tokens: u32,
46 pub cached_input_tokens: u32,
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
50#[serde(rename_all = "snake_case")]
51#[non_exhaustive]
52pub enum StopReason {
53 EndTurn,
54 ToolUse,
55 MaxTokens,
56 StopSequence,
57 Other,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62#[non_exhaustive]
63pub enum ModelDelta {
64 Text(String),
65 ToolCallStart {
66 id: String,
67 name: String,
68 },
69 ToolCallArgs {
70 id: String,
71 partial_json: String,
72 },
73 ToolCallEnd {
74 id: String,
75 },
76 Usage(Usage),
77 Stop(StopReason),
78 Reasoning(String),
84}
85
86#[async_trait]
87pub trait Model: Send + Sync + 'static {
88 async fn complete(&self, ctx: &Context) -> Result<ModelOutput, ModelError>;
89
90 async fn stream(
92 &self,
93 ctx: &Context,
94 ) -> Result<futures::stream::BoxStream<'static, Result<ModelDelta, ModelError>>, ModelError>
95 {
96 let out = self.complete(ctx).await?;
97 let deltas: Vec<Result<ModelDelta, ModelError>> = out
98 .text
99 .into_iter()
100 .map(|t| Ok(ModelDelta::Text(t)))
101 .chain(std::iter::once(Ok(ModelDelta::Stop(out.stop_reason))))
102 .collect();
103 Ok(Box::pin(futures::stream::iter(deltas)))
104 }
105
106 fn info(&self) -> ModelInfo;
107}
108
109pub struct DynModel(pub std::sync::Arc<dyn Model>);
120
121#[async_trait]
122impl Model for DynModel {
123 async fn complete(&self, ctx: &Context) -> Result<ModelOutput, ModelError> {
124 self.0.complete(ctx).await
125 }
126 async fn stream(
127 &self,
128 ctx: &Context,
129 ) -> Result<futures::stream::BoxStream<'static, Result<ModelDelta, ModelError>>, ModelError>
130 {
131 self.0.stream(ctx).await
132 }
133 fn info(&self) -> ModelInfo {
134 self.0.info()
135 }
136}
137
#[cfg(test)]
mod arc_model_tests {
    use super::*;
    use futures::StreamExt;
    use std::sync::Arc;

    /// Minimal model that always answers `"ok"` and does not override `stream`.
    struct Dummy;

    #[async_trait]
    impl Model for Dummy {
        async fn complete(&self, _ctx: &Context) -> Result<ModelOutput, ModelError> {
            Ok(ModelOutput {
                text: Some("ok".into()),
                tool_calls: vec![],
                usage: Usage::default(),
                stop_reason: StopReason::EndTurn,
                reasoning: None,
            })
        }

        fn info(&self) -> ModelInfo {
            ModelInfo {
                handle: "dummy".into(),
                provider: "test".into(),
                model: "dummy".into(),
                context_window: 8192,
                input_cost_usd_per_million_tokens: None,
                output_cost_usd_per_million_tokens: None,
                supports_tool_use: false,
                supports_streaming: false,
            }
        }
    }

    /// Compile-time check that `M` satisfies the `Model` bound.
    fn assert_is_model<M: Model>(_m: &M) {}

    /// Wraps a `Dummy` as a type-erased model inside `DynModel`.
    fn wrapped_dummy() -> DynModel {
        let inner: Arc<dyn Model> = Arc::new(Dummy);
        DynModel(inner)
    }

    /// Builds the throwaway context shared by the tests below.
    fn test_context() -> Context {
        Context::new(crate::Task {
            description: "x".into(),
            source: None,
            deadline: None,
        })
    }

    #[tokio::test]
    async fn dyn_model_wrapper_is_a_model() {
        let wrapped = wrapped_dummy();
        assert_is_model(&wrapped);

        let out = wrapped.complete(&test_context()).await.unwrap();
        assert_eq!(out.text.as_deref(), Some("ok"));
    }

    #[test]
    fn dyn_model_wrapper_forwards_info() {
        let info = wrapped_dummy().info();
        assert_eq!(info.handle, "dummy");
        assert_eq!(info.provider, "test");
        assert_eq!(info.context_window, 8192);
    }

    #[tokio::test]
    async fn dyn_model_wrapper_forwards_stream() {
        let wrapped = wrapped_dummy();
        let stream = wrapped.stream(&test_context()).await.unwrap();
        let deltas: Vec<ModelDelta> = stream.map(|delta| delta.unwrap()).collect().await;

        // `Dummy` uses the trait's fallback stream: the text must be replayed
        // and the stop marker must come last.
        assert!(deltas
            .iter()
            .any(|delta| matches!(delta, ModelDelta::Text(text) if text == "ok")));
        assert!(matches!(
            deltas.last(),
            Some(ModelDelta::Stop(StopReason::EndTurn))
        ));
    }
}