use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::harness::message::{AssistantMessage, Message};
use crate::harness::tool::{ToolDelta, ToolSchema};
use crate::harness::usage::Usage;
/// Tool-selection mode carried in [`ModelRequest::tool_choice`].
///
/// Variant names are serialized in snake_case using serde's default
/// (externally tagged) enum representation: unit variants become plain
/// strings such as `"auto"`, while [`ToolChoice::Tool`] becomes
/// `{"tool": "<value>"}` in JSON.
#[derive(Clone, Debug, Default, PartialEq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    /// The `Default` value. Presumably leaves the decision to call a tool to
    /// the model — confirm against the provider adapters.
    #[default]
    Auto,
    /// Presumably disables tool calls for the request — confirm with callers.
    None,
    /// Presumably forces the model to call some tool — confirm with callers.
    Required,
    /// Carries one string; presumably the name of the single tool to call.
    Tool(String),
}
/// Response-format option carried in [`ModelRequest::response_format`].
///
/// Variant names are serialized in snake_case (`text`, `json_object`,
/// `json_schema`) using serde's default externally tagged representation.
#[derive(Clone, Debug, PartialEq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
    /// Serialized as `"text"`.
    Text,
    /// Serialized as `"json_object"`.
    JsonObject,
    /// Serialized as `{"json_schema": {"name": ..., "schema": ...}}`.
    JsonSchema {
        /// Identifier accompanying the schema.
        name: String,
        /// Arbitrary JSON value; presumably a JSON Schema document — this
        /// type does not validate it.
        schema: Value,
    },
}
/// Role tag attached to a [`PromptSegment`].
///
/// A plain `Copy` enum whose variants serialize as the snake_case strings
/// `system`, `tools`, `instructions`, `history` and `volatile`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SegmentRole {
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"tools"`.
    Tools,
    /// Serialized as `"instructions"`.
    Instructions,
    /// Serialized as `"history"`.
    History,
    /// Serialized as `"volatile"`.
    Volatile,
}
/// One entry of [`ModelRequest::cache_segments`].
///
/// All three fields are required when deserializing (none has a serde
/// default).
#[derive(Clone, Debug, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct PromptSegment {
    /// Identifier of the segment.
    pub id: String,
    /// Role tag of the segment.
    pub role: SegmentRole,
    /// Flag marking the segment as cacheable; how it is honoured is decided
    /// by consumers outside this file.
    pub cacheable: bool,
}
/// A model suggestion, used in [`ModelSelection::hints`] and
/// [`ModelRequest::model_hints`].
#[derive(Clone, Debug, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct ModelHint {
    /// Name of the suggested model. Required when deserializing.
    pub model: String,
    /// Priority of the hint; `0` when absent from the input. The ordering
    /// convention (higher vs. lower wins) is not defined here — check the
    /// resolver.
    #[serde(default)]
    pub priority: i32,
    /// Optional free-form explanation; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
}
/// Origin tag stored in [`ResolvedModel::source`].
///
/// Presumably records which input decided the final model; the resolution
/// logic itself lives outside this file. Variants serialize as snake_case
/// strings.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelResolutionSource {
    /// Serialized as `"request_override"`.
    RequestOverride,
    /// Serialized as `"state_reuse"`.
    StateReuse,
    /// Serialized as `"hint"`.
    Hint,
    /// Serialized as `"agent_default"`.
    AgentDefault,
    /// Serialized as `"registry_default"`.
    RegistryDefault,
}
/// Outcome of model resolution: the chosen name plus where it came from.
///
/// Appears in [`ModelSelection::previous`], [`ModelResponse::resolved_model`]
/// and [`ResolvedModelBinding::resolved`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct ResolvedModel {
    /// Name of the resolved model.
    pub name: String,
    /// Optional requested name; omitted from output when `None`. Presumably
    /// the name the caller originally asked for — confirm with the resolver.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub requested: Option<String>,
    /// Tag describing the origin of the resolution.
    pub source: ModelResolutionSource,
}
/// Inputs gathered for choosing a model.
///
/// Every field has a serde default, so an empty JSON object deserializes to
/// `ModelSelection::default()`. Optional fields are omitted from output when
/// `None`; `reuse_previous` and `hints` are always written.
#[derive(Clone, Debug, Default, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct ModelSelection {
    /// Optional explicitly requested model name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub requested: Option<String>,
    /// Optional result of an earlier resolution.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub previous: Option<ResolvedModel>,
    /// Flag, `false` by default; presumably asks the resolver to keep
    /// `previous` when possible — confirm with the resolver.
    #[serde(default)]
    pub reuse_previous: bool,
    /// Model hints; empty by default.
    #[serde(default)]
    pub hints: Vec<ModelHint>,
    /// Optional agent-level default model name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub agent_default: Option<String>,
}
/// Request payload handed to [`ChatModel::invoke`] and [`ChatModel::stream`].
///
/// Only `messages` is mandatory when deserializing; every other field falls
/// back to its `Default` value. `Option` fields are left out of the
/// serialized form when `None`.
#[derive(Clone, Debug, Default)]
#[derive(Serialize, Deserialize)]
pub struct ModelRequest {
    /// Conversation messages. Required when deserializing.
    pub messages: Vec<Message>,
    /// Tool schemas attached to the request; empty by default.
    #[serde(default)]
    pub tools: Vec<ToolSchema>,
    /// Tool-selection mode; defaults to [`ToolChoice::Auto`].
    #[serde(default)]
    pub tool_choice: ToolChoice,
    /// Optional response format.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    /// Optional explicit model name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Model hints; empty by default.
    #[serde(default)]
    pub model_hints: Vec<ModelHint>,
    /// Flag, `false` by default; presumably requests reuse of a previously
    /// resolved model — confirm with the resolver.
    #[serde(default)]
    pub reuse_previous_model: bool,
    /// Optional sampling temperature.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Optional token limit.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Optional timeout; presumably in milliseconds, judging by the name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
    /// Arbitrary JSON metadata; `Value::Null` when absent.
    #[serde(default)]
    pub metadata: Value,
    /// Free-form string tags; empty by default.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Prompt segment descriptors; empty by default.
    #[serde(default)]
    pub cache_segments: Vec<PromptSegment>,
    /// Optional fingerprint string for the prompt; how it is computed is not
    /// defined in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_fingerprint: Option<String>,
}
/// Result payload returned by [`ChatModel::invoke`].
///
/// Only `message` is mandatory when deserializing; the optional fields are
/// omitted from output when `None`.
#[derive(Clone, Debug)]
#[derive(Serialize, Deserialize)]
pub struct ModelResponse {
    /// The assistant message produced by the model.
    pub message: AssistantMessage,
    /// Optional usage record.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    /// Optional finish reason, kept as a free-form string.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Optional raw JSON value; presumably the untouched provider payload.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub raw: Option<Value>,
    /// Optional record of which model was resolved for the call.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resolved_model: Option<ResolvedModel>,
}
/// One element of the vector returned by [`ChatModel::stream`].
#[derive(Clone, Debug, Default)]
#[derive(Serialize, Deserialize)]
pub struct ModelDelta {
    /// Identifier string for the delta. Required when deserializing. The
    /// default `stream` implementation fills it from the response message id.
    pub call_id: String,
    /// Text content; the empty string when absent from the input.
    #[serde(default)]
    pub content: String,
    /// Optional tool-call delta; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call: Option<ToolDelta>,
}
/// Asynchronous chat-model interface, generic over a caller-supplied `State`.
///
/// Implementors must be `Send + Sync`, and so must `State`. Only
/// [`ChatModel::invoke`] has to be implemented; [`ChatModel::stream`] has a
/// default built on top of it.
#[async_trait]
pub trait ChatModel<State: Send + Sync>: Send + Sync {
    /// Runs one request against the model and returns the full response.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports through the
    /// crate-level `Result` type.
    async fn invoke(&self, state: &State, request: ModelRequest) -> Result<ModelResponse>;

    /// Returns the response as a list of deltas.
    ///
    /// The default implementation does not stream incrementally: it awaits
    /// [`ChatModel::invoke`] and wraps the finished response in exactly one
    /// [`ModelDelta`]. That delta takes its `call_id` from the message id
    /// (or the id type's default when absent), its `content` from
    /// `response.text()`, and has no `tool_call`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`ChatModel::invoke`].
    async fn stream(&self, state: &State, request: ModelRequest) -> Result<Vec<ModelDelta>> {
        let response = self.invoke(state, request).await?;

        // Read the id before calling `text()`, matching the field evaluation
        // order of a struct literal built directly from `response`.
        let call_id = response.message.id.clone().unwrap_or_default();
        let content = response.text();

        let delta = ModelDelta {
            call_id,
            content,
            tool_call: None,
        };
        Ok(vec![delta])
    }
}
/// Collection of shared [`ChatModel`] trait objects keyed by string.
///
/// The `State: Send + Sync` bound is needed here because
/// `dyn ChatModel<State>` is only well-formed for such `State`.
pub struct ModelRegistry<State: Send + Sync> {
    /// Models stored behind `Arc`, keyed by a string (presumably the model
    /// name).
    pub(crate) models: HashMap<String, Arc<dyn ChatModel<State>>>,
    /// Optional string; presumably the key in `models` to fall back to when
    /// nothing else selects a model — confirm with the resolver.
    pub(crate) default: Option<String>,
}
/// Pairs a [`ResolvedModel`] record with a shared [`ChatModel`] trait object.
pub struct ResolvedModelBinding<State: Send + Sync> {
    /// Resolution record: chosen name, requested name and origin tag.
    pub resolved: ResolvedModel,
    /// Shared handle to a model implementation; presumably the one the
    /// record refers to.
    pub model: Arc<dyn ChatModel<State>>,
}
use crate::Result;