use crate::{Context, error::ModelError};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Descriptive metadata about a model, as returned by [`Model::info`].
///
/// The type derives serde's `Serialize`/`Deserialize` with no renaming
/// attributes, so fields are (de)serialized under their Rust names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Identifier for this model entry.
/// NOTE(review): presumably the short name callers use to select the model — confirm.
pub handle: String,
/// Name of the provider serving the model.
pub provider: String,
/// Model name. NOTE(review): presumably the provider-side model identifier — confirm.
pub model: String,
/// Size of the context window. NOTE(review): unit is presumably tokens — confirm.
pub context_window: u32,
/// Input price in USD per million tokens (per the field name); `None` when not set.
pub input_cost_usd_per_million_tokens: Option<f64>,
/// Output price in USD per million tokens (per the field name); `None` when not set.
pub output_cost_usd_per_million_tokens: Option<f64>,
/// Whether the model supports tool use.
pub supports_tool_use: bool,
/// Whether the model supports streaming.
pub supports_streaming: bool,
}
/// The complete result of a single [`Model::complete`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelOutput {
/// Text produced by the model, if any.
pub text: Option<String>,
/// Tool calls included in the output; may be empty.
pub tool_calls: Vec<ToolCall>,
/// Token counters reported for this call.
pub usage: Usage,
/// Why generation stopped.
pub stop_reason: StopReason,
/// Optional reasoning text.
///
/// Left out of the serialized form when `None`, and filled with `None`
/// when the field is absent on deserialization.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reasoning: Option<String>,
}
/// A single tool invocation carried in a [`ModelOutput`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call.
/// NOTE(review): presumably used to match a tool result back to the call — confirm.
pub id: String,
/// Name of the tool being called.
pub name: String,
/// Arguments for the call, held as an arbitrary JSON value.
pub args: serde_json::Value,
}
/// Token counters for a model call.
///
/// Derives `Default`, so `Usage::default()` has every counter set to zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
/// Number of input tokens.
pub input_tokens: u32,
/// Number of output tokens.
pub output_tokens: u32,
/// Number of cached input tokens.
/// NOTE(review): unclear from here whether this is a subset of
/// `input_tokens` or counted separately — confirm against the providers.
pub cached_input_tokens: u32,
}
/// Why a model stopped generating.
///
/// Serialized in `snake_case` (`"end_turn"`, `"tool_use"`, `"max_tokens"`,
/// `"stop_sequence"`, `"other"`). The enum is `#[non_exhaustive]`, so code
/// in other crates must include a wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum StopReason {
/// The model ended its turn. NOTE(review): presumably the normal completion case — confirm.
EndTurn,
/// The model stopped in order to use a tool.
ToolUse,
/// A maximum-token limit was reached.
MaxTokens,
/// A stop sequence was hit.
StopSequence,
/// Any reason not covered by the other variants.
Other,
}
/// One incremental event yielded by [`Model::stream`].
///
/// No serde representation attributes are set, so serde's default
/// (externally tagged) form with unrenamed variant names applies. The enum
/// is `#[non_exhaustive]`, so code in other crates must include a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ModelDelta {
/// A piece of output text.
Text(String),
/// Start of a tool call with the given `id` and tool `name`.
ToolCallStart { id: String, name: String },
/// A fragment of the JSON arguments for the tool call `id`.
/// NOTE(review): presumably fragments for one `id` concatenate into the
/// full arguments document — confirm against consumers.
ToolCallArgs { id: String, partial_json: String },
/// End of the tool call `id`.
ToolCallEnd { id: String },
/// Token counters for the call.
Usage(Usage),
/// Generation stopped for the given reason.
Stop(StopReason),
}
#[async_trait]
pub trait Model: Send + Sync + 'static {
async fn complete(&self, ctx: &Context) -> Result<ModelOutput, ModelError>;
async fn stream(
&self,
ctx: &Context,
) -> Result<futures::stream::BoxStream<'static, Result<ModelDelta, ModelError>>, ModelError> {
let out = self.complete(ctx).await?;
let deltas: Vec<Result<ModelDelta, ModelError>> = out
.text
.into_iter()
.map(|t| Ok(ModelDelta::Text(t)))
.chain(std::iter::once(Ok(ModelDelta::Stop(out.stop_reason))))
.collect();
Ok(Box::pin(futures::stream::iter(deltas)))
}
fn info(&self) -> ModelInfo;
}