use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::harness::cache::CachePolicy;
use crate::harness::message::{AssistantMessage, Message, MessageDelta};
use crate::harness::tool::{ToolDelta, ToolSchema};
use crate::harness::usage::Usage;
/// Controls whether and how the model may call the tools offered in a request.
///
/// Serialized by serde as an externally tagged enum with `snake_case` names:
/// `"auto"`, `"none"`, `"required"`, or `{"tool": "<name>"}`.
/// Defaults to [`ToolChoice::Auto`].
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
/// The default choice; presumably leaves the decision to call a tool to the model.
#[default]
Auto,
/// Presumably forbids tool calls for this request.
None,
/// Presumably requires the model to call at least one tool.
Required,
/// Names one specific tool; presumably the model must call that tool.
Tool(String),
}
/// Output format requested from the model.
///
/// Serialized by serde as an externally tagged enum with `snake_case` names,
/// e.g. `"text"`, `"json_object"`, or `{"json_schema": {"name": …, "schema": …}}`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseFormat {
/// Plain text output.
Text,
/// Any JSON object, with no schema attached.
JsonObject,
/// JSON constrained by `schema`; `name` labels the schema.
JsonSchema {
name: String,
schema: Value,
},
/// Carries the same `name`/`schema` pair as `JsonSchema`.
/// NOTE(review): presumably lets the adapter choose between native structured
/// output and a fallback strategy. Confirm against the provider adapters.
Auto {
name: String,
schema: Value,
},
}
/// Lifecycle stage of a model, as recorded in [`ModelProfile::status`].
///
/// Serialized with `snake_case` names (`"stable"`, `"preview"`, …).
/// Defaults to [`ModelStatus::Stable`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelStatus {
/// The default status.
#[default]
Stable,
Preview,
Deprecated,
Retired,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Modalities {
pub text_in: bool,
pub text_out: bool,
pub image_in: bool,
pub image_out: bool,
pub audio_in: bool,
pub audio_out: bool,
}
impl Default for Modalities {
fn default() -> Self {
Self {
text_in: true,
text_out: true,
image_in: false,
image_out: false,
audio_in: false,
audio_out: false,
}
}
}
/// Descriptive metadata and feature flags for one model, as returned by
/// [`ChatModel::profile`].
///
/// Every field has a serde default, so partial documents deserialize.
/// `Option` fields are omitted from serialized output when `None`, and every
/// `bool` flag defaults to `false`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelProfile {
/// Name of the provider serving this model, if set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub provider: Option<String>,
/// Model identifier, if set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Human-readable name, if set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub display_name: Option<String>,
/// Lifecycle status; defaults to [`ModelStatus::Stable`].
#[serde(default)]
pub status: ModelStatus,
/// Release date as free-form text.
/// NOTE(review): the format is not enforced here (ISO 8601?). Confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub release_date: Option<String>,
/// Supported modalities; defaults to `Modalities::default()` (text only).
#[serde(default)]
pub modalities: Modalities,
/// Supports tool/function calling.
#[serde(default)]
pub tool_calling: bool,
/// Can emit more than one tool call in a single response.
#[serde(default)]
pub parallel_tool_calls: bool,
/// Supports streamed responses.
#[serde(default)]
pub streaming: bool,
/// Streams tool calls incrementally rather than only whole calls.
#[serde(default)]
pub streaming_tool_chunks: bool,
/// Has provider-native structured output.
#[serde(default)]
pub native_structured_output: bool,
/// Accepts a JSON schema for structured output.
#[serde(default)]
pub json_schema: bool,
/// Exposes a reasoning capability.
#[serde(default)]
pub reasoning: bool,
/// Maximum input (context) tokens, if known.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_input_tokens: Option<u64>,
/// Maximum output tokens, if known.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_output_tokens: Option<u64>,
}
/// Capabilities a request needs, carried in [`ModelRequest::required_capabilities`].
///
/// Every `bool` defaults to `false` when absent. `Option` fields are omitted
/// from serialized output when `None`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct CapabilitySet {
#[serde(default)]
pub tool_calling: bool,
#[serde(default)]
pub parallel_tool_calls: bool,
#[serde(default)]
pub streaming: bool,
#[serde(default)]
pub streaming_tool_chunks: bool,
#[serde(default)]
pub native_structured_output: bool,
#[serde(default)]
pub json_schema: bool,
#[serde(default)]
pub reasoning: bool,
#[serde(default)]
pub image_in: bool,
#[serde(default)]
pub image_out: bool,
#[serde(default)]
pub audio_in: bool,
#[serde(default)]
pub audio_out: bool,
/// Presumably the smallest input-token limit an acceptable model must offer
/// (compare against `ModelProfile::max_input_tokens`). TODO confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub min_input_tokens: Option<u64>,
/// Presumably the smallest output-token limit an acceptable model must offer
/// (compare against `ModelProfile::max_output_tokens`). TODO confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub min_output_tokens: Option<u64>,
}
/// Kind of content a [`PromptSegment`] holds; serialized with `snake_case` names.
///
/// NOTE(review): presumably used to order segments for prompt caching
/// (stable parts first, `Volatile` last). Confirm with the cache layer.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SegmentRole {
System,
Tools,
Instructions,
History,
Volatile,
}
/// A labelled slice of the prompt, listed in [`ModelRequest::cache_segments`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct PromptSegment {
/// Identifier of the segment.
pub id: String,
/// What kind of content the segment holds.
pub role: SegmentRole,
/// Whether this segment may be cached; presumably consulted by the cache policy.
pub cacheable: bool,
}
/// A suggestion to use a particular model, with an optional rationale.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ModelHint {
/// Name of the suggested model.
pub model: String,
/// Relative priority; defaults to `0`.
/// NOTE(review): whether higher or lower wins is decided by the resolver. Confirm.
#[serde(default)]
pub priority: i32,
/// Free-form explanation for the hint, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
}
/// Which input determined the model chosen for a request; recorded in
/// [`ResolvedModel::source`] and serialized with `snake_case` names.
///
/// The variants presumably map onto the fields of [`ModelSelection`]:
/// `requested`, `previous`/`reuse_previous`, `hints`, `agent_default`, and the
/// registry's own default. TODO confirm against the resolver.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelResolutionSource {
RequestOverride,
StateReuse,
Hint,
AgentDefault,
RegistryDefault,
}
/// Outcome of model resolution: the chosen model and how it was chosen.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ResolvedModel {
/// Name of the model that was selected.
pub name: String,
/// Name that was originally requested, if any (may differ from `name`).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub requested: Option<String>,
/// Which input produced this selection.
pub source: ModelResolutionSource,
}
/// Inputs available when choosing a model. Every field is optional or defaults
/// to empty/`false`.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct ModelSelection {
/// Explicitly requested model name, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub requested: Option<String>,
/// Model resolved on a previous turn, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub previous: Option<ResolvedModel>,
/// Presumably asks resolution to keep `previous` when it is set.
#[serde(default)]
pub reuse_previous: bool,
/// Candidate models suggested for this call.
#[serde(default)]
pub hints: Vec<ModelHint>,
/// Default model configured for the agent, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub agent_default: Option<String>,
}
/// A single call to a chat model, as passed to [`ChatModel::invoke`] and
/// [`ChatModel::stream`].
///
/// Every field except `messages` has a serde default, so `messages` is the
/// only field required when deserializing. `Option` fields are omitted from
/// serialized output when `None`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ModelRequest {
/// Conversation to send to the model.
pub messages: Vec<Message>,
/// Tool schemas offered to the model; empty by default.
#[serde(default)]
pub tools: Vec<ToolSchema>,
/// How the model may use `tools`; defaults to [`ToolChoice::Auto`].
#[serde(default)]
pub tool_choice: ToolChoice,
/// Requested output format, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
/// Explicitly requested model name, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Model suggestions for resolution; empty by default.
#[serde(default)]
pub model_hints: Vec<ModelHint>,
/// Presumably asks resolution to reuse the previously resolved model.
#[serde(default)]
pub reuse_previous_model: bool,
/// Sampling temperature, if set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Presumably caps the number of generated tokens, if set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Request timeout, presumably in milliseconds (per the field name).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub timeout_ms: Option<u64>,
/// Free-form JSON metadata; defaults to `Value::Null`.
#[serde(default)]
pub metadata: Value,
/// Free-form labels attached to the request.
#[serde(default)]
pub tags: Vec<String>,
/// Prompt segments described for caching purposes.
#[serde(default)]
pub cache_segments: Vec<PromptSegment>,
/// Opaque fingerprint of the prompt, if computed.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_fingerprint: Option<String>,
/// Capabilities the serving model must have, if constrained.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub required_capabilities: Option<CapabilitySet>,
/// Provider-specific options as raw JSON; defaults to `Value::Null`.
#[serde(default)]
pub provider_options: Value,
/// Cache policy override, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cache_policy: Option<CachePolicy>,
/// Identifier for continuing an earlier exchange, if any.
/// NOTE(review): provider semantics not visible here. Confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub continuation_id: Option<String>,
}
/// A completed model response, returned by [`ChatModel::invoke`] and carried in
/// `ModelStreamItem::Completed`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelResponse {
/// The assistant message produced by the model.
pub message: AssistantMessage,
/// Token usage, if the provider reported it.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
/// Why generation stopped, as a provider-reported string, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub finish_reason: Option<String>,
/// Raw provider payload, if retained.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub raw: Option<Value>,
/// The model that actually served the request, if recorded.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub resolved_model: Option<ResolvedModel>,
}
/// An incremental chunk of model output tagged with the call it belongs to.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ModelDelta {
/// Identifier of the model call this chunk belongs to.
pub call_id: String,
/// Text content in this chunk; empty by default.
#[serde(default)]
pub content: String,
/// Partial tool-call data in this chunk, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_call: Option<ToolDelta>,
}
/// Structured failure reported by a model provider; carried in
/// `ModelStreamItem::ProviderFailed`.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct ProviderError {
/// Provider that reported the error.
pub provider: String,
/// Model involved, if known.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
/// Status code, presumably the HTTP status of the failed call, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub status: Option<u16>,
/// Provider-specific error code, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub code: Option<String>,
/// Human-readable error message.
pub message: String,
/// Whether retrying may succeed; defaults to `false`.
#[serde(default)]
pub retryable: bool,
/// Raw provider error payload, if retained.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub raw: Option<Value>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ModelStreamItem {
Started,
MessageDelta(MessageDelta),
ToolCallDelta(ToolDelta),
UsageDelta(Usage),
Completed(ModelResponse),
Failed(String),
ProviderFailed(ProviderError),
}
/// Boxed, pinned, `Send` stream of [`ModelStreamItem`]s, as returned by
/// [`ChatModel::stream`].
pub type ModelStream = Pin<Box<dyn Stream<Item = ModelStreamItem> + Send>>;
/// A chat-completion backend that can be driven with per-call `State`.
///
/// Implementors must provide [`ChatModel::invoke`]. [`ChatModel::profile`] and
/// [`ChatModel::stream`] have defaults.
#[async_trait]
pub trait ChatModel<State: Send + Sync>: Send + Sync {
    /// Static metadata describing this model. The default returns `None`.
    fn profile(&self) -> Option<&ModelProfile> {
        None
    }

    /// Runs one non-streaming completion for `request`.
    async fn invoke(&self, state: &State, request: ModelRequest) -> Result<ModelResponse>;

    /// Streams a completion for `request`.
    ///
    /// The default does not stream incrementally. It awaits
    /// [`ChatModel::invoke`] and then replays the finished response as exactly
    /// three items: `Started`, one `MessageDelta` holding the response text
    /// (with no tool call), and `Completed`. An error from `invoke` is returned
    /// directly; in that case no stream is created.
    async fn stream(&self, state: &State, request: ModelRequest) -> Result<ModelStream> {
        let finished = self.invoke(state, request).await?;
        // Extract the text before `finished` is moved into `Completed`.
        let full_text = MessageDelta {
            text: finished.text(),
            tool_call: None,
        };
        let replay = [
            ModelStreamItem::Started,
            ModelStreamItem::MessageDelta(full_text),
            ModelStreamItem::Completed(finished),
        ];
        Ok(Box::pin(futures::stream::iter(replay)))
    }
}
/// Name-keyed collection of chat models, plus an optional default entry.
pub struct ModelRegistry<State>
where
    State: Send + Sync,
{
    /// Registered models, keyed by name.
    pub(crate) models: HashMap<String, Arc<dyn ChatModel<State>>>,
    /// Name of the default model, if one is configured.
    /// NOTE(review): presumably a key of `models`; this is not enforced here.
    pub(crate) default: Option<String>,
}
/// A resolved model paired with a shared handle to its implementation.
pub struct ResolvedModelBinding<State>
where
    State: Send + Sync,
{
    /// Which model was chosen and why.
    pub resolved: ResolvedModel,
    /// Shared handle to the chosen model's implementation.
    pub model: Arc<dyn ChatModel<State>>,
}
use crate::Result;