use crate::error::{AgentLoopError, Result};
use crate::openresponses_protocol::{CompactRequest, CompactResponse};
use crate::runtime_agent::RuntimeAgent;
use crate::tool_types::{ToolCall, ToolDefinition};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use futures::Stream;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
/// Pinned, boxed, `Send` stream of incremental completion events returned by
/// [`LlmDriver::chat_completion_stream`]; each item is either an event or an error.
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>;
/// One incremental event emitted by a driver while a completion is streaming.
///
/// The default [`LlmDriver::chat_completion`] folds a stream of these into one [`LlmResponse`].
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// A chunk of assistant text; the default aggregation concatenates these in order.
    TextDelta(String),
    /// A chunk of model "thinking" text; the default aggregation concatenates these in order.
    ThinkingDelta(String),
    /// A signature for the thinking content; in the default aggregation the last one wins.
    ThinkingSignature(String),
    /// A provider-specific reasoning item.
    ///
    /// The default aggregation only reads `encrypted_content` (using it as the thinking
    /// signature when present); the remaining fields are ignored there.
    ReasonItem {
        provider: String,
        model: Option<String>,
        item_id: String,
        encrypted_content: Option<String>,
        summary: Vec<String>,
        // NOTE(review): presumably the token count of the reasoning item — confirm with drivers.
        token_count: Option<u32>,
    },
    /// Tool calls requested by the model. In the default aggregation a later `ToolCalls`
    /// event replaces an earlier one rather than appending to it.
    ToolCalls(Vec<ToolCall>),
    /// Final completion metadata; presumably boxed to keep the enum small.
    Done(Box<LlmCompletionMetadata>),
    /// An error reported in-band by the driver; the default aggregation converts it with
    /// `AgentLoopError::llm` and stops.
    Error(String),
}
/// A model reported by a driver's model listing ([`LlmDriver::list_models`]).
#[derive(Debug, Clone)]
pub struct DiscoveredModel {
    /// Provider-side model identifier.
    pub model_id: String,
    /// Human-readable name, if the provider supplied one.
    pub display_name: Option<String>,
    /// Creation timestamp, if the provider supplied one.
    pub created_at: Option<DateTime<Utc>>,
    /// Owner reported by the provider, if any (presumably the provider's `owned_by` field — confirm).
    pub owned_by: Option<String>,
    /// Driver-supplied [`crate::llm_models::LlmModelProfile`] for this model, if any.
    pub discovered_profile: Option<crate::llm_models::LlmModelProfile>,
}
/// Metadata describing a finished completion, carried by [`LlmStreamEvent::Done`].
///
/// Every field is optional; drivers fill in whatever their provider reports.
#[derive(Debug, Clone, Default)]
pub struct LlmCompletionMetadata {
    /// Total tokens for the call, as reported.
    pub total_tokens: Option<u32>,
    /// Prompt (input) tokens, as reported.
    pub prompt_tokens: Option<u32>,
    /// Completion (output) tokens, as reported.
    pub completion_tokens: Option<u32>,
    /// NOTE(review): presumably prompt tokens served from the provider's cache — confirm per driver.
    pub cache_read_tokens: Option<u32>,
    /// NOTE(review): presumably prompt tokens written to the provider's cache — confirm per driver.
    pub cache_creation_tokens: Option<u32>,
    /// Cost reported by the provider, in US dollars (per the field name).
    pub provider_cost_usd: Option<f64>,
    /// Model name reported back by the provider, if any.
    pub model: Option<String>,
    /// Provider finish/stop reason, as a raw string.
    pub finish_reason: Option<String>,
    /// Retry information from `crate::llm_retry`, if the driver recorded any.
    pub retry_metadata: Option<crate::llm_retry::RetryMetadata>,
    /// Provider response id, if any (presumably usable as `LlmCallConfig::previous_response_id` — confirm).
    pub response_id: Option<String>,
    /// Phase label reported by the driver, if any.
    pub phase: Option<String>,
}
#[async_trait]
pub trait LlmDriver: Send + Sync {
async fn chat_completion_stream(
&self,
messages: Vec<LlmMessage>,
config: &LlmCallConfig,
) -> Result<LlmResponseStream>;
async fn chat_completion(
&self,
messages: Vec<LlmMessage>,
config: &LlmCallConfig,
) -> Result<LlmResponse> {
use futures::StreamExt;
let mut stream = self.chat_completion_stream(messages, config).await?;
let mut text = String::new();
let mut thinking = String::new();
let mut thinking_signature: Option<String> = None;
let mut tool_calls = Vec::new();
let mut metadata = LlmCompletionMetadata::default();
while let Some(event) = stream.next().await {
match event? {
LlmStreamEvent::TextDelta(delta) => text.push_str(&delta),
LlmStreamEvent::ThinkingDelta(delta) => thinking.push_str(&delta),
LlmStreamEvent::ThinkingSignature(sig) => thinking_signature = Some(sig),
LlmStreamEvent::ReasonItem {
encrypted_content, ..
} => {
if let Some(sig) = encrypted_content {
thinking_signature = Some(sig);
}
}
LlmStreamEvent::ToolCalls(calls) => tool_calls = calls,
LlmStreamEvent::Done(meta) => metadata = *meta,
LlmStreamEvent::Error(err) => return Err(crate::error::AgentLoopError::llm(err)),
}
}
Ok(LlmResponse {
text,
thinking: if thinking.is_empty() {
None
} else {
Some(thinking)
},
thinking_signature,
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
metadata,
})
}
async fn list_models(&self) -> Result<Option<Vec<DiscoveredModel>>> {
Ok(None)
}
fn supports_compact(&self) -> bool {
false
}
async fn compact(&self, _request: CompactRequest) -> Result<Option<CompactResponse>> {
Ok(None)
}
}
/// Lets a boxed driver be used wherever an `LlmDriver` implementor is expected.
///
/// Every method forwards to the wrapped driver through `(**self)`, so the wrapped driver's
/// own overrides (not the trait defaults) are what run.
// NOTE: keep the explicit `(**self)`. Writing `self.method(..)` would resolve to this impl
// again and recurse forever, and coercing `self` to `&dyn LlmDriver` can unsize the Box
// itself with the same result.
#[async_trait]
impl LlmDriver for Box<dyn LlmDriver> {
    async fn chat_completion_stream(
        &self,
        messages: Vec<LlmMessage>,
        config: &LlmCallConfig,
    ) -> Result<LlmResponseStream> {
        (**self).chat_completion_stream(messages, config).await
    }
    async fn chat_completion(
        &self,
        messages: Vec<LlmMessage>,
        config: &LlmCallConfig,
    ) -> Result<LlmResponse> {
        (**self).chat_completion(messages, config).await
    }
    async fn list_models(&self) -> Result<Option<Vec<DiscoveredModel>>> {
        (**self).list_models().await
    }
    fn supports_compact(&self) -> bool {
        (**self).supports_compact()
    }
    async fn compact(&self, request: CompactRequest) -> Result<Option<CompactResponse>> {
        (**self).compact(request).await
    }
}
/// A single message in the provider-agnostic request format passed to an [`LlmDriver`].
#[derive(Debug, Clone)]
pub struct LlmMessage {
    /// Author role.
    pub role: LlmMessageRole,
    /// Message body: plain text or a list of text/image/audio parts.
    pub content: LlmMessageContent,
    /// Tool calls attached to the message; the conversions in this module use `None` rather
    /// than an empty list.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Tool-call id copied from the source message (presumably set on tool-result messages).
    pub tool_call_id: Option<String>,
    /// Execution phase copied from the source message.
    pub phase: Option<crate::message::ExecutionPhase>,
    /// Model "thinking" text carried alongside the message.
    pub thinking: Option<String>,
    /// Signature accompanying `thinking`.
    pub thinking_signature: Option<String>,
}
impl LlmMessage {
    /// Builds a message whose content is a single plain-text string.
    ///
    /// Tool-call, phase and thinking fields start out empty.
    pub fn text(role: LlmMessageRole, content: impl Into<String>) -> Self {
        Self {
            role,
            content: LlmMessageContent::Text(content.into()),
            tool_calls: None,
            tool_call_id: None,
            phase: None,
            thinking: None,
            thinking_signature: None,
        }
    }

    /// Builds a multi-part message from `parts`.
    ///
    /// Tool-call, phase and thinking fields start out empty.
    pub fn parts(role: LlmMessageRole, parts: Vec<LlmContentPart>) -> Self {
        Self {
            role,
            content: LlmMessageContent::Parts(parts),
            tool_calls: None,
            tool_call_id: None,
            phase: None,
            thinking: None,
            thinking_signature: None,
        }
    }

    /// Returns the message content flattened to text (non-text parts are skipped).
    pub fn content_as_text(&self) -> String {
        self.content.to_text()
    }

    /// Prepends `prefix` to the message text, e.g. to tag a speaker with `"[Alice] "`.
    ///
    /// For multi-part content the prefix is added to the first text part. If there is no text
    /// part, a new text part containing just `prefix` is inserted at the front.
    pub fn prepend_text_prefix(&mut self, prefix: &str) {
        match &mut self.content {
            // Insert in place rather than formatting a brand-new string.
            LlmMessageContent::Text(text) => text.insert_str(0, prefix),
            LlmMessageContent::Parts(parts) => {
                let first_text = parts.iter_mut().find_map(|part| match part {
                    LlmContentPart::Text { text } => Some(text),
                    _ => None,
                });
                match first_text {
                    Some(text) => text.insert_str(0, prefix),
                    None => parts.insert(
                        0,
                        LlmContentPart::Text {
                            text: prefix.to_string(),
                        },
                    ),
                }
            }
        }
    }
}
/// Body of an [`LlmMessage`]: either a single string or an ordered list of parts.
#[derive(Debug, Clone)]
pub enum LlmMessageContent {
    /// Plain text.
    Text(String),
    /// Ordered multimodal parts (text, image, audio).
    Parts(Vec<LlmContentPart>),
}
impl LlmMessageContent {
    /// Flattens the content to plain text.
    ///
    /// `Text` is returned as-is. For `Parts`, the text parts are concatenated in order with no
    /// separator, and image and audio parts are skipped.
    pub fn to_text(&self) -> String {
        match self {
            LlmMessageContent::Text(s) => s.clone(),
            // Borrow each text part and collect straight into one String instead of cloning
            // every part into a temporary Vec and joining it.
            LlmMessageContent::Parts(parts) => parts
                .iter()
                .filter_map(|p| match p {
                    LlmContentPart::Text { text } => Some(text.as_str()),
                    // Listed explicitly so a new part kind forces a decision here.
                    LlmContentPart::Image { .. } | LlmContentPart::Audio { .. } => None,
                })
                .collect(),
        }
    }
    /// Returns `true` for [`LlmMessageContent::Text`].
    pub fn is_text(&self) -> bool {
        matches!(self, LlmMessageContent::Text(_))
    }
    /// Returns `true` for [`LlmMessageContent::Parts`].
    pub fn is_parts(&self) -> bool {
        matches!(self, LlmMessageContent::Parts(_))
    }
}
impl From<String> for LlmMessageContent {
    /// Wraps an owned string as [`LlmMessageContent::Text`].
    fn from(s: String) -> Self {
        LlmMessageContent::Text(s)
    }
}
impl From<&str> for LlmMessageContent {
    /// Copies a string slice into [`LlmMessageContent::Text`].
    fn from(s: &str) -> Self {
        LlmMessageContent::Text(s.to_string())
    }
}
/// One element of multi-part message content.
#[derive(Debug, Clone)]
pub enum LlmContentPart {
    /// A text segment.
    Text { text: String },
    /// An image referenced by URL. Besides regular URLs, this module also produces
    /// `data:<media-type>;base64,...` URLs for inline and resolved images.
    Image { url: String },
    /// An audio clip referenced by URL.
    Audio { url: String },
}
impl LlmContentPart {
    /// Creates a [`LlmContentPart::Text`] part.
    pub fn text(text: impl Into<String>) -> Self {
        LlmContentPart::Text { text: text.into() }
    }
    /// Creates a [`LlmContentPart::Image`] part from a URL (regular or `data:`).
    pub fn image(url: impl Into<String>) -> Self {
        LlmContentPart::Image { url: url.into() }
    }
    /// Creates a [`LlmContentPart::Audio`] part from a URL.
    pub fn audio(url: impl Into<String>) -> Self {
        LlmContentPart::Audio { url: url.into() }
    }
}
/// Author role of an [`LlmMessage`].
///
/// The conversions from stored messages in this module map System→System, User→User,
/// Agent→Assistant and ToolResult→Tool.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LlmMessageRole {
    System,
    User,
    Assistant,
    Tool,
}
/// Settings for tool search, passed to drivers via [`LlmCallConfig::tool_search`].
// NOTE(review): the meaning of `threshold` is not visible here — presumably the tool count
// above which tool search is used; confirm against the drivers that read it.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct ToolSearchConfig {
    /// Whether tool search is enabled.
    pub enabled: bool,
    pub threshold: usize,
}
/// How prompt caching is applied. Serialized in `snake_case` (e.g. `"auto"`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum PromptCacheStrategy {
    /// The default, and currently only, strategy.
    #[default]
    Auto,
}
/// Prompt-caching settings for a call.
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct PromptCacheConfig {
    /// Whether prompt caching is requested. This field has no `#[serde(default)]`, so it must
    /// be present when deserializing.
    pub enabled: bool,
    /// Caching strategy; defaults to [`PromptCacheStrategy::Auto`] when absent.
    #[serde(default)]
    pub strategy: PromptCacheStrategy,
    /// Name of a Gemini cached-content resource, if any (presumably only read by the Gemini
    /// driver — confirm). Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gemini_cached_content: Option<String>,
}
/// Per-call settings handed to an [`LlmDriver`] alongside the messages.
///
/// Usually built from a [`RuntimeAgent`] via `From` or [`LlmCallConfigBuilder`].
#[derive(Debug, Clone)]
pub struct LlmCallConfig {
    /// Model identifier to request.
    pub model: String,
    /// Sampling temperature; `None` means this config does not set one.
    pub temperature: Option<f32>,
    /// Output token cap; `None` means this config does not set one.
    pub max_tokens: Option<u32>,
    /// Tools offered to the model.
    pub tools: Vec<ToolDefinition>,
    /// Reasoning-effort hint as a free-form string (e.g. `"high"` in the tests).
    pub reasoning_effort: Option<String>,
    /// Arbitrary key/value pairs (e.g. `session_id`, `agent_id` in the tests).
    pub metadata: HashMap<String, String>,
    /// NOTE(review): presumably the id of a prior provider response to continue from — confirm per driver.
    pub previous_response_id: Option<String>,
    /// Tool-search settings, if any.
    pub tool_search: Option<ToolSearchConfig>,
    /// Prompt-caching settings, if any.
    pub prompt_cache: Option<PromptCacheConfig>,
}
impl From<&RuntimeAgent> for LlmCallConfig {
fn from(runtime_agent: &RuntimeAgent) -> Self {
Self {
model: runtime_agent.model.clone(),
temperature: runtime_agent.temperature,
max_tokens: runtime_agent.max_tokens,
tools: runtime_agent.tools.clone(),
reasoning_effort: None, metadata: HashMap::new(), previous_response_id: None,
tool_search: runtime_agent.tool_search.clone(),
prompt_cache: runtime_agent.prompt_cache.clone(),
}
}
}
/// A completed (non-streaming) LLM response, as returned by [`LlmDriver::chat_completion`].
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Concatenated assistant text.
    pub text: String,
    /// Concatenated thinking text; the default `chat_completion` uses `None` when there was none.
    pub thinking: Option<String>,
    /// Signature for the thinking content, if any.
    pub thinking_signature: Option<String>,
    /// Requested tool calls; the default `chat_completion` uses `None` rather than an empty list.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Completion metadata (defaults when the stream sent no `Done` event).
    pub metadata: LlmCompletionMetadata,
}
pub struct LlmCallConfigBuilder {
config: LlmCallConfig,
}
impl LlmCallConfigBuilder {
pub fn from(runtime_agent: &RuntimeAgent) -> Self {
Self {
config: LlmCallConfig::from(runtime_agent),
}
}
pub fn reasoning_effort(mut self, effort: impl Into<String>) -> Self {
self.config.reasoning_effort = Some(effort.into());
self
}
pub fn model(mut self, model: impl Into<String>) -> Self {
self.config.model = model.into();
self
}
pub fn temperature(mut self, temp: f32) -> Self {
self.config.temperature = Some(temp);
self
}
pub fn max_tokens(mut self, tokens: u32) -> Self {
self.config.max_tokens = Some(tokens);
self
}
pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
self.config.tools = tools;
self
}
pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
self.config.metadata = metadata;
self
}
pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.config.metadata.insert(key.into(), value.into());
self
}
pub fn previous_response_id(mut self, id: Option<String>) -> Self {
self.config.previous_response_id = id;
self
}
pub fn tool_search(mut self, config: ToolSearchConfig) -> Self {
self.config.tool_search = Some(config);
self
}
pub fn prompt_cache(mut self, config: PromptCacheConfig) -> Self {
self.config.prompt_cache = Some(config);
self
}
pub fn build(self) -> LlmCallConfig {
self.config
}
}
impl From<&crate::message::Message> for LlmMessage {
fn from(msg: &crate::message::Message) -> Self {
let role = match msg.role {
crate::message::MessageRole::System => LlmMessageRole::System,
crate::message::MessageRole::User => LlmMessageRole::User,
crate::message::MessageRole::Agent => LlmMessageRole::Assistant,
crate::message::MessageRole::ToolResult => LlmMessageRole::Tool,
};
let tool_calls: Vec<ToolCall> = msg
.tool_calls()
.into_iter()
.map(|tc| ToolCall {
id: tc.id.clone(),
name: tc.name.clone(),
arguments: tc.arguments.clone(),
})
.collect();
LlmMessage {
role,
content: LlmMessageContent::Text(msg.content_to_llm_string()),
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
tool_call_id: msg.tool_call_id().map(|s| s.to_string()),
phase: msg.phase,
thinking: msg.thinking.clone(),
thinking_signature: msg.thinking_signature.clone(),
}
}
}
use crate::traits::ResolvedImage;
use uuid::Uuid;
impl LlmMessage {
    /// Converts a stored message into an [`LlmMessage`], resolving image attachments.
    ///
    /// Each content part of `msg` is translated in order:
    /// - text parts are copied as text parts;
    /// - inline images become image parts using their `url` if set, otherwise a
    ///   `data:<media_type>;base64,<data>` URL; an image with neither is dropped;
    /// - image-file references are looked up in `resolved_images` by UUID and inlined as data
    ///   URLs, or replaced by an `[Image not found: …]` text placeholder when missing;
    /// - tool calls go into `tool_calls` rather than the content;
    /// - tool results are rendered as text (the error message if present, otherwise the
    ///   JSON-serialized result, otherwise `{}`) and capped by `truncate_tool_result`.
    ///
    /// A result consisting of exactly one text part is collapsed to
    /// [`LlmMessageContent::Text`], and an empty result becomes empty text.
    pub fn from_message_with_images(
        msg: &crate::message::Message,
        resolved_images: &HashMap<Uuid, ResolvedImage>,
    ) -> Self {
        use crate::message::{ContentPart, MessageRole};
        let role = match msg.role {
            MessageRole::System => LlmMessageRole::System,
            MessageRole::User => LlmMessageRole::User,
            MessageRole::Agent => LlmMessageRole::Assistant,
            MessageRole::ToolResult => LlmMessageRole::Tool,
        };
        let mut parts: Vec<LlmContentPart> = Vec::new();
        let mut tool_calls: Vec<ToolCall> = Vec::new();
        for part in &msg.content {
            match part {
                ContentPart::Text(t) => {
                    parts.push(LlmContentPart::Text {
                        text: t.text.clone(),
                    });
                }
                ContentPart::Image(img) => {
                    if let Some(url) = &img.url {
                        parts.push(LlmContentPart::Image { url: url.clone() });
                    } else if let (Some(base64), Some(media_type)) = (&img.base64, &img.media_type)
                    {
                        let data_url = format!("data:{};base64,{}", media_type, base64);
                        parts.push(LlmContentPart::Image { url: data_url });
                    }
                }
                ContentPart::ImageFile(img_file) => {
                    if let Some(resolved) = resolved_images.get(&img_file.image_id.uuid()) {
                        parts.push(LlmContentPart::Image {
                            url: resolved.to_data_url(),
                        });
                    } else {
                        // Keep a visible marker so the model knows an attachment was lost.
                        parts.push(LlmContentPart::Text {
                            text: format!("[Image not found: {}]", img_file.image_id),
                        });
                    }
                }
                ContentPart::ToolCall(tc) => {
                    tool_calls.push(ToolCall {
                        id: tc.id.clone(),
                        name: tc.name.clone(),
                        arguments: tc.arguments.clone(),
                    });
                }
                ContentPart::ToolResult(tr) => {
                    let text = if let Some(err) = &tr.error {
                        format!("Tool error: {}", err)
                    } else if let Some(res) = &tr.result {
                        serde_json::to_string(res).unwrap_or_else(|_| "{}".to_string())
                    } else {
                        "{}".to_string()
                    };
                    let text = truncate_tool_result(text);
                    parts.push(LlmContentPart::Text { text });
                }
            }
        }
        // A lone text part (or nothing at all) collapses to plain text; anything else stays
        // multi-part.
        let content = match parts.as_slice() {
            [] => LlmMessageContent::Text(String::new()),
            [LlmContentPart::Text { text }] => LlmMessageContent::Text(text.clone()),
            _ => LlmMessageContent::Parts(parts),
        };
        LlmMessage {
            role,
            content,
            tool_calls: if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            },
            tool_call_id: msg.tool_call_id().map(|s| s.to_string()),
            phase: msg.phase,
            thinking: msg.thinking.clone(),
            thinking_signature: msg.thinking_signature.clone(),
        }
    }

    /// Returns `true` if any content part of `msg` is an image-file reference.
    pub fn message_has_image_files(msg: &crate::message::Message) -> bool {
        msg.content.iter().any(|p| p.is_image_file())
    }

    /// Collects the UUIDs of every image-file reference in `msg`, in content order.
    ///
    /// These are the keys `from_message_with_images` looks up in `resolved_images`.
    pub fn extract_image_file_ids(msg: &crate::message::Message) -> Vec<Uuid> {
        msg.content
            .iter()
            .filter_map(|p| match p {
                crate::message::ContentPart::ImageFile(f) => Some(f.image_id.uuid()),
                _ => None,
            })
            .collect()
    }
}
/// LLM provider backends that a driver can be registered for.
///
/// Each variant has a canonical lowercase name. `Display` writes that name, and `FromStr`
/// accepts it case-insensitively.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ProviderType {
    /// `"openai"`
    OpenAI,
    /// `"azure_openai"`
    AzureOpenAI,
    /// `"openai_completions"`
    OpenAICompletions,
    /// `"anthropic"`
    Anthropic,
    /// `"gemini"`
    Gemini,
    /// `"llmsim"`; exempt from the API-key requirement in `DriverRegistry::create_driver`.
    LlmSim,
}

/// Every [`ProviderType`] variant, used by `FromStr` to look names up.
const ALL_PROVIDER_TYPES: [ProviderType; 6] = [
    ProviderType::OpenAI,
    ProviderType::AzureOpenAI,
    ProviderType::OpenAICompletions,
    ProviderType::Anthropic,
    ProviderType::Gemini,
    ProviderType::LlmSim,
];

/// Canonical lowercase name of `provider`, shared by `Display` and `FromStr`.
fn provider_wire_name(provider: &ProviderType) -> &'static str {
    match provider {
        ProviderType::OpenAI => "openai",
        ProviderType::AzureOpenAI => "azure_openai",
        ProviderType::OpenAICompletions => "openai_completions",
        ProviderType::Anthropic => "anthropic",
        ProviderType::Gemini => "gemini",
        ProviderType::LlmSim => "llmsim",
    }
}

impl std::str::FromStr for ProviderType {
    type Err = String;
    /// Parses a canonical provider name, ignoring case.
    ///
    /// # Errors
    /// Returns `"Unknown provider type: <input>"` (echoing the original input) for any other name.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let wanted = s.to_lowercase();
        ALL_PROVIDER_TYPES
            .iter()
            .find(|candidate| provider_wire_name(candidate) == wanted)
            .cloned()
            .ok_or_else(|| format!("Unknown provider type: {}", s))
    }
}

impl std::fmt::Display for ProviderType {
    /// Writes the canonical lowercase name, e.g. `azure_openai`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(provider_wire_name(self))
    }
}
/// Settings needed to instantiate a driver through [`DriverRegistry::create_driver`].
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Provider whose registered factory builds the driver.
    pub provider_type: ProviderType,
    /// API key handed to the factory; `create_driver` rejects `None` for every provider
    /// except [`ProviderType::LlmSim`].
    pub api_key: Option<String>,
    /// Optional endpoint override, handed to the factory unchanged.
    pub base_url: Option<String>,
}
impl ProviderConfig {
    /// Starts a config for `provider_type` with neither an API key nor a base URL.
    pub fn new(provider_type: ProviderType) -> Self {
        ProviderConfig {
            provider_type,
            api_key: None,
            base_url: None,
        }
    }
    /// Returns the config with `api_key` set, replacing any previous key.
    pub fn with_api_key(self, api_key: impl Into<String>) -> Self {
        ProviderConfig {
            api_key: Some(api_key.into()),
            ..self
        }
    }
    /// Returns the config with `base_url` set, replacing any previous URL.
    pub fn with_base_url(self, base_url: impl Into<String>) -> Self {
        ProviderConfig {
            base_url: Some(base_url.into()),
            ..self
        }
    }
}
/// Heap-allocated, dynamically dispatched driver.
pub type BoxedLlmDriver = Box<dyn LlmDriver>;
/// Shared driver constructor, called with the API key and optional base URL by
/// [`DriverRegistry::create_driver`].
pub type DriverFactory = Arc<dyn Fn(&str, Option<&str>) -> BoxedLlmDriver + Send + Sync>;
/// Maps each [`ProviderType`] to the factory that builds its driver.
#[derive(Clone, Default)]
pub struct DriverRegistry {
    factories: HashMap<ProviderType, DriverFactory>,
}
impl DriverRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `factory` for `provider_type`, replacing any factory previously registered
    /// for that provider.
    pub fn register<F>(&mut self, provider_type: ProviderType, factory: F)
    where
        F: Fn(&str, Option<&str>) -> BoxedLlmDriver + Send + Sync + 'static,
    {
        self.factories.insert(provider_type, Arc::new(factory));
    }

    /// Builds a driver for `config` using the registered factory.
    ///
    /// # Errors
    /// - `AgentLoopError::driver_not_registered` if no factory is registered for
    ///   `config.provider_type`. This is checked first, so a missing driver is never
    ///   misreported as a missing API key.
    /// - `AgentLoopError::llm` if `config.api_key` is `None` for any provider other than
    ///   [`ProviderType::LlmSim`]. The simulator gets an empty key instead.
    pub fn create_driver(&self, config: &ProviderConfig) -> Result<BoxedLlmDriver> {
        let factory = self.factories.get(&config.provider_type).ok_or_else(|| {
            AgentLoopError::driver_not_registered(config.provider_type.to_string())
        })?;
        let api_key = match (&config.provider_type, config.api_key.as_deref()) {
            (_, Some(key)) => key,
            // The simulator does not need credentials.
            (ProviderType::LlmSim, None) => "",
            (_, None) => {
                return Err(AgentLoopError::llm(
                    "API key is required. Configure the API key in provider settings.",
                ));
            }
        };
        Ok(factory(api_key, config.base_url.as_deref()))
    }

    /// Returns `true` if a factory is registered for `provider_type`.
    pub fn has_driver(&self, provider_type: &ProviderType) -> bool {
        self.factories.contains_key(provider_type)
    }

    /// Lists the providers that have a registered factory, in unspecified (hash-map) order.
    pub fn registered_providers(&self) -> Vec<ProviderType> {
        self.factories.keys().cloned().collect()
    }
}
/// Upper bound, in bytes, on a single tool result forwarded to the model.
const MAX_TOOL_RESULT_BYTES: usize = 64 * 1024;
/// Marker appended to a tool result that had to be cut down to fit [`MAX_TOOL_RESULT_BYTES`].
const TRUNCATION_SUFFIX: &str =
    "\n\n[Output truncated — exceeded 64 KiB limit. Try quiet flags, pipes, or redirect to file.]";

/// Caps `text` at [`MAX_TOOL_RESULT_BYTES`] bytes, suffix included.
///
/// Text within the limit is returned untouched. Longer text keeps as much of its prefix as
/// fits alongside [`TRUNCATION_SUFFIX`]. The cut is moved back to a UTF-8 character boundary,
/// so no character is split.
fn truncate_tool_result(text: String) -> String {
    if text.len() > MAX_TOOL_RESULT_BYTES {
        let budget = MAX_TOOL_RESULT_BYTES.saturating_sub(TRUNCATION_SUFFIX.len());
        // Index 0 is always a boundary, so the search cannot come up empty.
        let cut = (0..=budget)
            .rev()
            .find(|&idx| text.is_char_boundary(idx))
            .unwrap_or(0);
        let mut clipped = String::with_capacity(cut + TRUNCATION_SUFFIX.len());
        clipped.push_str(&text[..cut]);
        clipped.push_str(TRUNCATION_SUFFIX);
        clipped
    } else {
        text
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ContentPart, ImageFileContentPart, Message, MessageRole, TextContentPart};

    /// Driver stub for registry tests; its methods are never invoked.
    struct MockDriver;

    #[async_trait]
    impl LlmDriver for MockDriver {
        async fn chat_completion_stream(
            &self,
            _messages: Vec<LlmMessage>,
            _config: &LlmCallConfig,
        ) -> Result<LlmResponseStream> {
            unimplemented!()
        }
    }

    /// Factory that ignores its arguments and returns a [`MockDriver`].
    fn mock_driver_factory(_api_key: &str, _base_url: Option<&str>) -> BoxedLlmDriver {
        Box::new(MockDriver)
    }

    /// Builds a user message with the given content and every optional field unset.
    fn user_message(content: Vec<ContentPart>) -> Message {
        Message {
            id: uuid::Uuid::new_v4().into(),
            role: MessageRole::User,
            content,
            phase: None,
            thinking: None,
            thinking_signature: None,
            controls: None,
            metadata: None,
            external_actor: None,
            created_at: chrono::Utc::now(),
        }
    }

    /// Shorthand for a text content part.
    fn text_part(text: &str) -> ContentPart {
        ContentPart::Text(TextContentPart {
            text: text.to_string(),
        })
    }

    /// Shorthand for an image-file content part.
    fn image_file_part(image_id: uuid::Uuid, filename: &str) -> ContentPart {
        ContentPart::ImageFile(ImageFileContentPart {
            image_id: image_id.into(),
            filename: Some(filename.to_string()),
        })
    }

    #[test]
    fn test_llm_call_config_builder_from_runtime_agent() {
        let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
        let llm_config = LlmCallConfigBuilder::from(&runtime_agent).build();
        assert_eq!(llm_config.model, "gpt-4o");
        assert!(llm_config.reasoning_effort.is_none());
        assert!(llm_config.temperature.is_none());
        assert!(llm_config.max_tokens.is_none());
        assert!(llm_config.tools.is_empty());
        assert!(llm_config.metadata.is_empty());
    }

    #[test]
    fn test_llm_call_config_builder_with_metadata() {
        let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
        let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
            .with_metadata("session_id", "session_abc123")
            .with_metadata("agent_id", "agent_xyz789")
            .build();
        assert_eq!(
            llm_config.metadata.get("session_id"),
            Some(&"session_abc123".to_string())
        );
        assert_eq!(
            llm_config.metadata.get("agent_id"),
            Some(&"agent_xyz789".to_string())
        );
    }

    #[test]
    fn test_llm_call_config_builder_with_metadata_hashmap() {
        let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
        let mut metadata = HashMap::new();
        metadata.insert("key1".to_string(), "value1".to_string());
        metadata.insert("key2".to_string(), "value2".to_string());
        let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
            .metadata(metadata)
            .build();
        assert_eq!(llm_config.metadata.get("key1"), Some(&"value1".to_string()));
        assert_eq!(llm_config.metadata.get("key2"), Some(&"value2".to_string()));
    }

    #[test]
    fn test_llm_call_config_builder_with_reasoning_effort() {
        let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
        let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
            .reasoning_effort("high")
            .build();
        assert_eq!(llm_config.reasoning_effort, Some("high".to_string()));
    }

    #[test]
    fn test_llm_call_config_builder_with_all_options() {
        let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
        let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
            .model("claude-3-opus")
            .reasoning_effort("medium")
            .temperature(0.7)
            .max_tokens(1000)
            .build();
        assert_eq!(llm_config.model, "claude-3-opus");
        assert_eq!(llm_config.reasoning_effort, Some("medium".to_string()));
        assert_eq!(llm_config.temperature, Some(0.7));
        assert_eq!(llm_config.max_tokens, Some(1000));
    }

    #[test]
    fn test_provider_type_parsing() {
        assert_eq!(
            "openai".parse::<ProviderType>().unwrap(),
            ProviderType::OpenAI
        );
        assert_eq!(
            "openai_completions".parse::<ProviderType>().unwrap(),
            ProviderType::OpenAICompletions
        );
        assert_eq!(
            "azure_openai".parse::<ProviderType>().unwrap(),
            ProviderType::AzureOpenAI
        );
        assert_eq!(
            "anthropic".parse::<ProviderType>().unwrap(),
            ProviderType::Anthropic
        );
        assert_eq!(
            "gemini".parse::<ProviderType>().unwrap(),
            ProviderType::Gemini
        );
        assert_eq!(
            "llmsim".parse::<ProviderType>().unwrap(),
            ProviderType::LlmSim
        );
        // Parsing ignores case.
        assert_eq!(
            "OpenAI".parse::<ProviderType>().unwrap(),
            ProviderType::OpenAI
        );
        assert!("ollama".parse::<ProviderType>().is_err());
        assert!("custom".parse::<ProviderType>().is_err());
    }

    #[test]
    fn test_provider_type_display() {
        assert_eq!(ProviderType::OpenAI.to_string(), "openai");
        assert_eq!(ProviderType::AzureOpenAI.to_string(), "azure_openai");
        assert_eq!(
            ProviderType::OpenAICompletions.to_string(),
            "openai_completions"
        );
        assert_eq!(ProviderType::Anthropic.to_string(), "anthropic");
        assert_eq!(ProviderType::Gemini.to_string(), "gemini");
        assert_eq!(ProviderType::LlmSim.to_string(), "llmsim");
    }

    #[test]
    fn test_provider_config_builder() {
        let config = ProviderConfig::new(ProviderType::Anthropic)
            .with_api_key("test-key")
            .with_base_url("https://custom.api.com");
        assert_eq!(config.provider_type, ProviderType::Anthropic);
        assert_eq!(config.api_key, Some("test-key".to_string()));
        assert_eq!(config.base_url, Some("https://custom.api.com".to_string()));
    }

    #[test]
    fn test_driver_registry_requires_api_key() {
        let mut registry = DriverRegistry::new();
        registry.register(ProviderType::OpenAI, mock_driver_factory);
        let config = ProviderConfig::new(ProviderType::OpenAI);
        let result = registry.create_driver(&config);
        assert!(result.is_err());
        let config_with_key = ProviderConfig::new(ProviderType::OpenAI).with_api_key("test-key");
        let result = registry.create_driver(&config_with_key);
        assert!(result.is_ok());
    }

    #[test]
    fn test_driver_registry_llmsim_does_not_require_api_key() {
        let mut registry = DriverRegistry::new();
        registry.register(ProviderType::LlmSim, mock_driver_factory);
        let config = ProviderConfig::new(ProviderType::LlmSim);
        assert!(registry.create_driver(&config).is_ok());
    }

    #[test]
    fn test_driver_registry_returns_error_for_unregistered_provider() {
        let registry = DriverRegistry::new();
        let config = ProviderConfig::new(ProviderType::Anthropic).with_api_key("test-key");
        let result = registry.create_driver(&config);
        if let Err(AgentLoopError::DriverNotRegistered(provider)) = result {
            assert_eq!(provider, "anthropic");
        } else {
            panic!("Expected DriverNotRegistered error");
        }
    }

    #[test]
    fn test_driver_registry_registration() {
        let mut registry = DriverRegistry::new();
        assert!(!registry.has_driver(&ProviderType::OpenAI));
        assert!(!registry.has_driver(&ProviderType::Anthropic));
        registry.register(ProviderType::OpenAI, mock_driver_factory);
        assert!(registry.has_driver(&ProviderType::OpenAI));
        assert!(!registry.has_driver(&ProviderType::Anthropic));
    }

    #[test]
    fn test_message_has_image_files_with_image_file() {
        let message = user_message(vec![
            text_part("Look at this image"),
            image_file_part(uuid::Uuid::new_v4(), "test.png"),
        ]);
        assert!(LlmMessage::message_has_image_files(&message));
    }

    #[test]
    fn test_message_has_image_files_without_image_file() {
        let message = user_message(vec![text_part("Just text")]);
        assert!(!LlmMessage::message_has_image_files(&message));
    }

    #[test]
    fn test_extract_image_file_ids() {
        let id1 = uuid::Uuid::new_v4();
        let id2 = uuid::Uuid::new_v4();
        let message = user_message(vec![
            text_part("Look at these images"),
            image_file_part(id1, "test1.png"),
            image_file_part(id2, "test2.png"),
        ]);
        let ids = LlmMessage::extract_image_file_ids(&message);
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&id1));
        assert!(ids.contains(&id2));
    }

    #[test]
    fn test_from_message_with_images_text_only() {
        let message = user_message(vec![text_part("Hello")]);
        let resolved = std::collections::HashMap::new();
        let llm_message = LlmMessage::from_message_with_images(&message, &resolved);
        assert_eq!(llm_message.role, LlmMessageRole::User);
        match llm_message.content {
            LlmMessageContent::Text(text) => assert_eq!(text, "Hello"),
            _ => panic!("Expected text content"),
        }
    }

    #[test]
    fn test_from_message_with_images_empty_content() {
        let message = user_message(vec![]);
        let resolved = std::collections::HashMap::new();
        let llm_message = LlmMessage::from_message_with_images(&message, &resolved);
        match llm_message.content {
            LlmMessageContent::Text(text) => assert!(text.is_empty()),
            _ => panic!("Expected empty text content"),
        }
        assert!(llm_message.tool_calls.is_none());
    }

    #[test]
    fn test_from_message_with_images_resolved_image() {
        let image_id = uuid::Uuid::new_v4();
        let message = user_message(vec![
            text_part("Look at this"),
            image_file_part(image_id, "test.png"),
        ]);
        let mut resolved = std::collections::HashMap::new();
        resolved.insert(
            image_id,
            crate::ResolvedImage::new("base64data", "image/png"),
        );
        let llm_message = LlmMessage::from_message_with_images(&message, &resolved);
        match &llm_message.content {
            LlmMessageContent::Parts(parts) => {
                assert_eq!(parts.len(), 2);
                assert!(matches!(&parts[0], LlmContentPart::Text { .. }));
                if let LlmContentPart::Image { url } = &parts[1] {
                    assert!(url.starts_with("data:image/png;base64,"));
                } else {
                    panic!("Expected image content part");
                }
            }
            _ => panic!("Expected parts content"),
        }
    }

    #[test]
    fn test_from_message_with_images_unresolved_image() {
        let message = user_message(vec![image_file_part(uuid::Uuid::new_v4(), "missing.png")]);
        let resolved = std::collections::HashMap::new();
        let llm_message = LlmMessage::from_message_with_images(&message, &resolved);
        match &llm_message.content {
            LlmMessageContent::Text(text) => {
                assert!(text.contains("Image not found"));
            }
            LlmMessageContent::Parts(parts) => {
                assert_eq!(parts.len(), 1);
                if let LlmContentPart::Text { text } = &parts[0] {
                    assert!(text.contains("Image not found"));
                } else {
                    panic!("Expected text placeholder for missing image");
                }
            }
        }
    }

    #[test]
    fn test_content_to_text_skips_non_text_parts() {
        let content = LlmMessageContent::Parts(vec![
            LlmContentPart::text("a"),
            LlmContentPart::image("data:image/png;base64,abc"),
            LlmContentPart::audio("https://example.com/clip.wav"),
            LlmContentPart::text("b"),
        ]);
        assert_eq!(content.to_text(), "ab");
        assert!(content.is_parts());
        assert!(LlmMessageContent::from("hi").is_text());
    }

    #[test]
    fn test_prepend_text_prefix_simple_text() {
        let mut msg = LlmMessage::text(LlmMessageRole::User, "Hello bot");
        msg.prepend_text_prefix("[Alice] ");
        assert_eq!(msg.content_as_text(), "[Alice] Hello bot");
    }

    #[test]
    fn test_prepend_text_prefix_parts() {
        let mut msg = LlmMessage::parts(
            LlmMessageRole::User,
            vec![
                LlmContentPart::Text {
                    text: "Hello".to_string(),
                },
                LlmContentPart::Image {
                    url: "data:image/png;base64,abc".to_string(),
                },
            ],
        );
        msg.prepend_text_prefix("[Bob] ");
        match &msg.content {
            LlmMessageContent::Parts(parts) => {
                if let LlmContentPart::Text { text } = &parts[0] {
                    assert_eq!(text, "[Bob] Hello");
                } else {
                    panic!("Expected text part");
                }
            }
            _ => panic!("Expected parts content"),
        }
    }

    #[test]
    fn test_prepend_text_prefix_parts_no_text() {
        let mut msg = LlmMessage::parts(
            LlmMessageRole::User,
            vec![LlmContentPart::Image {
                url: "data:image/png;base64,abc".to_string(),
            }],
        );
        msg.prepend_text_prefix("[Eve] ");
        match &msg.content {
            LlmMessageContent::Parts(parts) => {
                assert_eq!(parts.len(), 2);
                if let LlmContentPart::Text { text } = &parts[0] {
                    assert_eq!(text, "[Eve] ");
                } else {
                    panic!("Expected prepended text part");
                }
            }
            _ => panic!("Expected parts content"),
        }
    }

    #[test]
    fn test_truncate_tool_result_keeps_output_within_limit() {
        let text = "x".repeat(MAX_TOOL_RESULT_BYTES);
        assert_eq!(truncate_tool_result(text.clone()), text);
    }

    #[test]
    fn test_truncate_tool_result_caps_long_output() {
        let out = truncate_tool_result("y".repeat(MAX_TOOL_RESULT_BYTES * 2));
        assert_eq!(out.len(), MAX_TOOL_RESULT_BYTES);
        assert!(out.ends_with(TRUNCATION_SUFFIX));
    }

    #[test]
    fn test_truncate_tool_result_respects_char_boundaries() {
        // "€" is three bytes wide, so the byte budget does not land on a character boundary.
        let out = truncate_tool_result("€".repeat(MAX_TOOL_RESULT_BYTES));
        assert!(out.len() <= MAX_TOOL_RESULT_BYTES);
        let kept = out
            .strip_suffix(TRUNCATION_SUFFIX)
            .expect("truncation suffix must be appended");
        assert!(kept.chars().all(|c| c == '€'));
    }
}