1use crate::{Context, error::ModelError};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelInfo {
12 pub handle: String,
13 pub provider: String,
14 pub model: String,
15 pub context_window: u32,
16 pub input_cost_usd_per_million_tokens: Option<f64>,
17 pub output_cost_usd_per_million_tokens: Option<f64>,
18 pub supports_tool_use: bool,
19 pub supports_streaming: bool,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ModelOutput {
24 pub text: Option<String>,
25 pub tool_calls: Vec<ToolCall>,
26 pub usage: Usage,
27 pub stop_reason: StopReason,
28 #[serde(default, skip_serializing_if = "Option::is_none")]
32 pub reasoning: Option<String>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ToolCall {
37 pub id: String,
38 pub name: String,
39 pub args: serde_json::Value,
40}
41
42#[derive(Debug, Clone, Default, Serialize, Deserialize)]
43pub struct Usage {
44 pub input_tokens: u32,
45 pub output_tokens: u32,
46 pub cached_input_tokens: u32,
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
50#[serde(rename_all = "snake_case")]
51#[non_exhaustive]
52pub enum StopReason {
53 EndTurn,
54 ToolUse,
55 MaxTokens,
56 StopSequence,
57 Other,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62#[non_exhaustive]
63pub enum ModelDelta {
64 Text(String),
65 ToolCallStart { id: String, name: String },
66 ToolCallArgs { id: String, partial_json: String },
67 ToolCallEnd { id: String },
68 Usage(Usage),
69 Stop(StopReason),
70 Reasoning(String),
76}
77
78#[async_trait]
79pub trait Model: Send + Sync + 'static {
80 async fn complete(&self, ctx: &Context) -> Result<ModelOutput, ModelError>;
81
82 async fn stream(
84 &self,
85 ctx: &Context,
86 ) -> Result<futures::stream::BoxStream<'static, Result<ModelDelta, ModelError>>, ModelError>
87 {
88 let out = self.complete(ctx).await?;
89 let deltas: Vec<Result<ModelDelta, ModelError>> = out
90 .text
91 .into_iter()
92 .map(|t| Ok(ModelDelta::Text(t)))
93 .chain(std::iter::once(Ok(ModelDelta::Stop(out.stop_reason))))
94 .collect();
95 Ok(Box::pin(futures::stream::iter(deltas)))
96 }
97
98 fn info(&self) -> ModelInfo;
99}
100
101pub struct DynModel(pub std::sync::Arc<dyn Model>);
112
113#[async_trait]
114impl Model for DynModel {
115 async fn complete(&self, ctx: &Context) -> Result<ModelOutput, ModelError> {
116 self.0.complete(ctx).await
117 }
118 async fn stream(
119 &self,
120 ctx: &Context,
121 ) -> Result<futures::stream::BoxStream<'static, Result<ModelDelta, ModelError>>, ModelError>
122 {
123 self.0.stream(ctx).await
124 }
125 fn info(&self) -> ModelInfo {
126 self.0.info()
127 }
128}
129
#[cfg(test)]
mod arc_model_tests {
    use super::*;
    use futures::StreamExt;
    use std::sync::Arc;

    /// Minimal model: always answers `"ok"` with `EndTurn` and no tool calls.
    /// It does not override `stream`.
    struct Dummy;

    #[async_trait]
    impl Model for Dummy {
        async fn complete(&self, _ctx: &Context) -> Result<ModelOutput, ModelError> {
            Ok(ModelOutput {
                text: Some("ok".into()),
                tool_calls: vec![],
                usage: Usage::default(),
                stop_reason: StopReason::EndTurn,
                reasoning: None,
            })
        }

        fn info(&self) -> ModelInfo {
            ModelInfo {
                handle: "dummy".into(),
                provider: "test".into(),
                model: "dummy".into(),
                context_window: 8192,
                input_cost_usd_per_million_tokens: None,
                output_cost_usd_per_million_tokens: None,
                supports_tool_use: false,
                supports_streaming: false,
            }
        }
    }

    /// Compile-time check that `M` implements [`Model`].
    fn assert_is_model<M: Model>(_m: &M) {}

    #[tokio::test]
    async fn dyn_model_wrapper_is_a_model() {
        let m: Arc<dyn Model> = Arc::new(Dummy);
        let wrapped = DynModel(m);
        assert_is_model(&wrapped);
        let out = wrapped
            .complete(&Context::new(crate::Task {
                description: "x".into(),
                source: None,
                deadline: None,
            }))
            .await
            .unwrap();
        assert_eq!(out.text.as_deref(), Some("ok"));
    }

    /// `DynModel::stream` must forward to the inner model. `Dummy` relies on the trait's
    /// default adapter, so the stream should open with its text and close with its stop
    /// reason.
    #[tokio::test]
    async fn dyn_model_wrapper_streams() {
        let m: Arc<dyn Model> = Arc::new(Dummy);
        let wrapped = DynModel(m);
        let deltas: Vec<_> = wrapped
            .stream(&Context::new(crate::Task {
                description: "x".into(),
                source: None,
                deadline: None,
            }))
            .await
            .unwrap()
            .collect()
            .await;
        assert!(matches!(deltas.first(), Some(Ok(ModelDelta::Text(t))) if t == "ok"));
        assert!(matches!(
            deltas.last(),
            Some(Ok(ModelDelta::Stop(StopReason::EndTurn)))
        ));
    }
}