Skip to main content

harness_core/
model.rs

1use crate::{Context, error::ModelError};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4
5/// Information about a configured model — uniform across providers.
6///
7/// `handle` is the user-chosen logical identifier (used in logs, metrics,
8/// and `harness.toml` selectors); `model` is the wire-protocol model id
9/// sent to the provider.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelInfo {
12    pub handle: String,
13    pub provider: String,
14    pub model: String,
15    pub context_window: u32,
16    pub input_cost_usd_per_million_tokens: Option<f64>,
17    pub output_cost_usd_per_million_tokens: Option<f64>,
18    pub supports_tool_use: bool,
19    pub supports_streaming: bool,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ModelOutput {
24    pub text: Option<String>,
25    pub tool_calls: Vec<ToolCall>,
26    pub usage: Usage,
27    pub stop_reason: StopReason,
28    /// Provider-specific reasoning trace (DeepSeek `reasoning_content`,
29    /// Anthropic `thinking` blocks). Pushed back to the API verbatim on
30    /// subsequent calls; required by providers that gate on it.
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub reasoning: Option<String>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ToolCall {
37    pub id: String,
38    pub name: String,
39    pub args: serde_json::Value,
40}
41
42#[derive(Debug, Clone, Default, Serialize, Deserialize)]
43pub struct Usage {
44    pub input_tokens: u32,
45    pub output_tokens: u32,
46    pub cached_input_tokens: u32,
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
50#[serde(rename_all = "snake_case")]
51#[non_exhaustive]
52pub enum StopReason {
53    EndTurn,
54    ToolUse,
55    MaxTokens,
56    StopSequence,
57    Other,
58}
59
60/// Streaming delta — incremental output from the model.
61#[derive(Debug, Clone, Serialize, Deserialize)]
62#[non_exhaustive]
63pub enum ModelDelta {
64    Text(String),
65    ToolCallStart { id: String, name: String },
66    ToolCallArgs { id: String, partial_json: String },
67    ToolCallEnd { id: String },
68    Usage(Usage),
69    Stop(StopReason),
70    /// Provider-specific reasoning trace that must round-trip on the next
71    /// request. DeepSeek thinking content, Anthropic thinking blocks, and
72    /// Gemini raw `parts` (with thoughtSignatures) all flow through this —
73    /// the AgentLoop folds them into the final `ModelOutput.reasoning`
74    /// without surfacing them to user-visible token streams.
75    Reasoning(String),
76}
77
78#[async_trait]
79pub trait Model: Send + Sync + 'static {
80    async fn complete(&self, ctx: &Context) -> Result<ModelOutput, ModelError>;
81
82    /// Streaming is optional; default implementation falls back to `complete`.
83    async fn stream(
84        &self,
85        ctx: &Context,
86    ) -> Result<futures::stream::BoxStream<'static, Result<ModelDelta, ModelError>>, ModelError>
87    {
88        let out = self.complete(ctx).await?;
89        let deltas: Vec<Result<ModelDelta, ModelError>> = out
90            .text
91            .into_iter()
92            .map(|t| Ok(ModelDelta::Text(t)))
93            .chain(std::iter::once(Ok(ModelDelta::Stop(out.stop_reason))))
94            .collect();
95        Ok(Box::pin(futures::stream::iter(deltas)))
96    }
97
98    fn info(&self) -> ModelInfo;
99}
100
101/// A concrete newtype wrapping a boxed model, so an `Arc<dyn Model>` can be used
102/// where a concrete `M: Model` is required (e.g. `Subagent::new` / `AgentLoop<M>`).
103///
104/// We deliberately do NOT `impl Model for Arc<dyn Model>` directly. Doing so
105/// changes `.stream()` method resolution on EVERY `Arc<dyn Model>` value in the
106/// program (from a deref to `dyn Model` into the Arc impl's `async fn stream`
107/// RPITIT), and proving that boxed streaming future is `Send` inside a `Send`
108/// context (e.g. an axum handler driving the streaming loop) overflows the
109/// auto-trait solver (E0275). Wrapping in this concrete newtype gives callers
110/// `DynModel: Model` without touching resolution for bare `Arc<dyn Model>`.
111pub struct DynModel(pub std::sync::Arc<dyn Model>);
112
113#[async_trait]
114impl Model for DynModel {
115    async fn complete(&self, ctx: &Context) -> Result<ModelOutput, ModelError> {
116        self.0.complete(ctx).await
117    }
118    async fn stream(
119        &self,
120        ctx: &Context,
121    ) -> Result<futures::stream::BoxStream<'static, Result<ModelDelta, ModelError>>, ModelError>
122    {
123        self.0.stream(ctx).await
124    }
125    fn info(&self) -> ModelInfo {
126        self.0.info()
127    }
128}
129
#[cfg(test)]
mod arc_model_tests {
    use super::*;
    use std::sync::Arc;

    /// Minimal model that always answers `"ok"` and keeps the trait's default
    /// `stream`.
    struct Dummy;

    #[async_trait]
    impl Model for Dummy {
        async fn complete(&self, _ctx: &Context) -> Result<ModelOutput, ModelError> {
            let canned = ModelOutput {
                text: Some(String::from("ok")),
                tool_calls: Vec::new(),
                usage: Usage::default(),
                stop_reason: StopReason::EndTurn,
                reasoning: None,
            };
            Ok(canned)
        }

        fn info(&self) -> ModelInfo {
            ModelInfo {
                handle: String::from("dummy"),
                provider: String::from("test"),
                model: String::from("dummy"),
                context_window: 8192,
                input_cost_usd_per_million_tokens: None,
                output_cost_usd_per_million_tokens: None,
                supports_tool_use: false,
                supports_streaming: false,
            }
        }
    }

    /// Compile-time probe: only accepts references to `Model` implementors.
    fn assert_is_model<M: Model>(_m: &M) {}

    #[tokio::test]
    async fn dyn_model_wrapper_is_a_model() {
        let wrapped = DynModel(Arc::new(Dummy) as Arc<dyn Model>);
        // This line type-checks only when `DynModel: Model` holds.
        assert_is_model(&wrapped);

        let task = crate::Task {
            description: "x".into(),
            source: None,
            deadline: None,
        };
        let ctx = Context::new(task);

        let out = wrapped.complete(&ctx).await.unwrap();
        assert_eq!(out.text.as_deref(), Some("ok"));
    }
}