mod types;
use std::sync::Arc;
use serde_json::Value;
use crate::harness::message::{AssistantMessage, ContentBlock, Message};
use crate::harness::tool::{ToolCall, ToolSchema};
use crate::harness::usage::Usage;
pub use types::*;
impl ResponseFormat {
pub fn json_schema(name: impl Into<String>, schema: Value) -> Self {
ResponseFormat::JsonSchema {
name: name.into(),
schema,
}
}
}
impl ModelRequest {
    /// Creates a request carrying `messages`; every other field takes its `Default` value.
    pub fn new(messages: Vec<Message>) -> Self {
        Self {
            messages,
            ..Self::default()
        }
    }

    /// Replaces the tool schemas offered to the model.
    pub fn with_tools(self, tools: Vec<ToolSchema>) -> Self {
        Self { tools, ..self }
    }

    /// Replaces the tool-choice setting.
    pub fn with_tool_choice(self, choice: ToolChoice) -> Self {
        Self {
            tool_choice: choice,
            ..self
        }
    }

    /// Sets the requested response format.
    pub fn with_response_format(self, format: ResponseFormat) -> Self {
        Self {
            response_format: Some(format),
            ..self
        }
    }

    /// Pins an explicit model name for this request.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Appends one model hint; earlier hints are kept.
    pub fn with_model_hint(mut self, hint: ModelHint) -> Self {
        self.model_hints.push(hint);
        self
    }

    /// Sets whether a previously resolved model may be reused for this request.
    pub fn with_reuse_previous_model(self, reuse: bool) -> Self {
        Self {
            reuse_previous_model: reuse,
            ..self
        }
    }

    /// Sets the sampling temperature.
    pub fn with_temperature(self, temperature: f64) -> Self {
        Self {
            temperature: Some(temperature),
            ..self
        }
    }

    /// Sets the maximum token budget.
    pub fn with_max_tokens(self, max_tokens: u32) -> Self {
        Self {
            max_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Sets the request timeout in milliseconds.
    pub fn with_timeout_ms(self, timeout_ms: u64) -> Self {
        Self {
            timeout_ms: Some(timeout_ms),
            ..self
        }
    }

    /// Appends one tag; earlier tags are kept.
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Replaces the prompt cache segments.
    pub fn with_cache_segments(self, segments: Vec<PromptSegment>) -> Self {
        Self {
            cache_segments: segments,
            ..self
        }
    }

    /// Returns the ids of every segment marked `cacheable`, in segment order.
    ///
    /// NOTE(review): despite the "prefix" in the name, this includes cacheable segments that
    /// follow a non-cacheable one — confirm that is the intended contract.
    pub fn cacheable_prefix_ids(&self) -> Vec<String> {
        self.cache_segments
            .iter()
            .filter_map(|segment| segment.cacheable.then(|| segment.id.clone()))
            .collect()
    }
}
impl ModelResponse {
    /// Builds a response whose assistant message holds a single text block with `content`.
    ///
    /// The message id, tool calls and usage start empty. The response-level usage, finish
    /// reason, raw payload and resolved model also start as `None`.
    pub fn assistant(content: impl Into<String>) -> Self {
        let text_block = ContentBlock::Text(content.into());
        let message = AssistantMessage {
            id: None,
            content: vec![text_block],
            tool_calls: Vec::new(),
            usage: None,
        };
        Self {
            message,
            usage: None,
            finish_reason: None,
            raw: None,
            resolved_model: None,
        }
    }

    /// Records `usage` on both the response itself and its assistant message.
    pub fn with_usage(mut self, usage: Usage) -> Self {
        // `Usage` is `Copy`, so the same value can be stored in both places.
        let recorded = Some(usage);
        self.message.usage = recorded;
        self.usage = recorded;
        self
    }

    /// Sets the finish reason reported by the provider.
    pub fn with_finish_reason(mut self, reason: impl Into<String>) -> Self {
        let reason = reason.into();
        self.finish_reason = Some(reason);
        self
    }

    /// Records which model produced this response.
    pub fn with_resolved_model(mut self, resolved: ResolvedModel) -> Self {
        self.resolved_model = Some(resolved);
        self
    }

    /// Borrows the tool calls carried by the assistant message.
    pub fn tool_calls(&self) -> &[ToolCall] {
        self.message.tool_calls.as_slice()
    }

    /// Returns the message text by delegating to `Message::text`.
    ///
    /// NOTE(review): this clones the whole assistant message just to wrap it in
    /// `Message::Assistant`; consider a borrowing accessor if this shows up in profiles.
    pub fn text(&self) -> String {
        let wrapped = Message::Assistant(self.message.clone());
        wrapped.text()
    }
}
impl<State: Send + Sync> ModelRegistry<State> {
pub fn new() -> Self {
Self {
models: std::collections::HashMap::new(),
default: None,
}
}
pub fn register(
&mut self,
name: impl Into<String>,
model: Arc<dyn ChatModel<State>>,
) -> &mut Self {
let name = name.into();
if self.default.is_none() {
self.default = Some(name.clone());
}
self.models.insert(name, model);
self
}
pub fn set_default(&mut self, name: impl Into<String>) -> &mut Self {
self.default = Some(name.into());
self
}
pub fn get(&self, name: &str) -> Option<Arc<dyn ChatModel<State>>> {
self.models.get(name).cloned()
}
pub fn default_model(&self) -> Option<Arc<dyn ChatModel<State>>> {
self.default.as_deref().and_then(|name| self.get(name))
}
pub fn default_name(&self) -> Option<&str> {
self.default.as_deref()
}
pub fn resolve(&self, selection: ModelSelection) -> Option<ResolvedModelBinding<State>> {
if let Some(requested) = selection.requested
&& let Some(model) = self.get(&requested)
{
return Some(ResolvedModelBinding {
resolved: ResolvedModel {
name: requested.clone(),
requested: Some(requested),
source: ModelResolutionSource::RequestOverride,
},
model,
});
}
if selection.reuse_previous
&& let Some(previous) = selection.previous
&& let Some(model) = self.get(&previous.name)
{
return Some(ResolvedModelBinding {
resolved: ResolvedModel {
name: previous.name,
requested: previous.requested,
source: ModelResolutionSource::StateReuse,
},
model,
});
}
let mut hints: Vec<(usize, ModelHint)> = selection.hints.into_iter().enumerate().collect();
hints.sort_by(|(left_index, left), (right_index, right)| {
right
.priority
.cmp(&left.priority)
.then_with(|| left_index.cmp(right_index))
});
for (_, hint) in hints {
if let Some(model) = self.get(&hint.model) {
return Some(ResolvedModelBinding {
resolved: ResolvedModel {
name: hint.model.clone(),
requested: Some(hint.model),
source: ModelResolutionSource::Hint,
},
model,
});
}
}
if let Some(agent_default) = selection.agent_default
&& let Some(model) = self.get(&agent_default)
{
return Some(ResolvedModelBinding {
resolved: ResolvedModel {
name: agent_default.clone(),
requested: Some(agent_default),
source: ModelResolutionSource::AgentDefault,
},
model,
});
}
let name = self.default_name()?.to_string();
self.default_model().map(|model| ResolvedModelBinding {
resolved: ResolvedModel {
name,
requested: None,
source: ModelResolutionSource::RegistryDefault,
},
model,
})
}
pub fn resolve_request(
&self,
request: &ModelRequest,
agent_default: Option<&str>,
previous: Option<ResolvedModel>,
) -> Option<ResolvedModelBinding<State>> {
self.resolve(ModelSelection {
requested: request.model.clone(),
previous,
reuse_previous: request.reuse_previous_model,
hints: request.model_hints.clone(),
agent_default: agent_default.map(ToOwned::to_owned),
})
}
pub fn names(&self) -> Vec<String> {
let mut names: Vec<String> = self.models.keys().cloned().collect();
names.sort();
names
}
}
impl<State: Send + Sync> Default for ModelRegistry<State> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod test;