1use crate::error::{AgentLoopError, Result};
18use crate::openresponses_protocol::{CompactRequest, CompactResponse};
19use crate::runtime_agent::RuntimeAgent;
20use crate::tool_types::{ToolCall, ToolDefinition};
21use async_trait::async_trait;
22use chrono::{DateTime, Utc};
23use futures::Stream;
24use std::collections::HashMap;
25use std::pin::Pin;
26use std::sync::Arc;
27
/// Boxed, `Send` stream of [`LlmStreamEvent`]s returned by
/// [`LlmDriver::chat_completion_stream`]; every item is itself a `Result`.
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>;

/// One incremental event emitted by a streaming LLM call.
///
/// The default [`LlmDriver::chat_completion`] folds a stream of these into an
/// [`LlmResponse`].
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// A chunk of response text; the default `chat_completion` appends chunks in arrival order.
    TextDelta(String),
    /// A chunk of thinking/reasoning text; appended in arrival order by the default `chat_completion`.
    ThinkingDelta(String),
    /// Signature for the thinking content; the default `chat_completion` keeps the last one seen.
    ThinkingSignature(String),
    /// A provider reasoning item.
    ///
    /// The default `chat_completion` reads only `encrypted_content`: when present it
    /// overwrites the collected thinking signature. The other fields are ignored there.
    ReasonItem {
        /// Provider that produced the item.
        provider: String,
        /// Model that produced the item, if reported.
        model: Option<String>,
        /// Identifier of the reasoning item.
        item_id: String,
        /// Opaque encrypted reasoning payload, if the provider returned one.
        encrypted_content: Option<String>,
        /// Reasoning summary segments.
        summary: Vec<String>,
        /// Token count attributed to the item, if reported.
        token_count: Option<u32>,
    },
    /// Tool calls requested by the model. In the default `chat_completion` a later
    /// `ToolCalls` event replaces (does not extend) an earlier one.
    ToolCalls(Vec<ToolCall>),
    /// Completion metadata (usage, finish reason, ...) for the call.
    Done(Box<LlmCompletionMetadata>),
    /// Provider-reported error; the default `chat_completion` converts it with
    /// `AgentLoopError::llm` and stops reading the stream.
    Error(String),
}
71
/// A model entry returned by [`LlmDriver::list_models`].
#[derive(Debug, Clone)]
pub struct DiscoveredModel {
    /// Model identifier as reported by the provider.
    pub model_id: String,
    /// Human-readable name, if the provider reports one.
    pub display_name: Option<String>,
    /// Creation timestamp (UTC), if the provider reports one.
    pub created_at: Option<DateTime<Utc>>,
    /// Owner/organization string, if the provider reports one.
    pub owned_by: Option<String>,
    /// Optional `LlmModelProfile` attached to the model at discovery time.
    pub discovered_profile: Option<crate::llm_models::LlmModelProfile>,
}
96
/// Usage and bookkeeping data for one completion.
///
/// Delivered by [`LlmStreamEvent::Done`] and exposed as [`LlmResponse::metadata`].
/// The default value (every field `None`) is what the default
/// [`LlmDriver::chat_completion`] reports when the stream never sends `Done`.
#[derive(Debug, Clone, Default)]
pub struct LlmCompletionMetadata {
    /// Total tokens for the call, if reported.
    pub total_tokens: Option<u32>,
    /// Prompt (input) tokens, if reported.
    pub prompt_tokens: Option<u32>,
    /// Completion (output) tokens, if reported.
    pub completion_tokens: Option<u32>,
    /// Tokens served from a prompt cache, if reported.
    pub cache_read_tokens: Option<u32>,
    /// Tokens written into a prompt cache, if reported.
    pub cache_creation_tokens: Option<u32>,
    /// Provider-reported cost; US dollars per the field name.
    pub provider_cost_usd: Option<f64>,
    /// Model that actually served the request, if reported.
    pub model: Option<String>,
    /// Provider finish/stop reason, if reported.
    pub finish_reason: Option<String>,
    /// Retry bookkeeping for the call, if any.
    pub retry_metadata: Option<crate::llm_retry::RetryMetadata>,
    /// Provider response id, if reported.
    // NOTE(review): presumably the value later fed into `LlmCallConfig::previous_response_id` — confirm.
    pub response_id: Option<String>,
    /// Phase label reported for the response, if any.
    pub phase: Option<String>,
}
134
135#[async_trait]
139pub trait LlmDriver: Send + Sync {
140 async fn chat_completion_stream(
142 &self,
143 messages: Vec<LlmMessage>,
144 config: &LlmCallConfig,
145 ) -> Result<LlmResponseStream>;
146
147 async fn chat_completion(
149 &self,
150 messages: Vec<LlmMessage>,
151 config: &LlmCallConfig,
152 ) -> Result<LlmResponse> {
153 use futures::StreamExt;
154
155 let mut stream = self.chat_completion_stream(messages, config).await?;
156 let mut text = String::new();
157 let mut thinking = String::new();
158 let mut thinking_signature: Option<String> = None;
159 let mut tool_calls = Vec::new();
160 let mut metadata = LlmCompletionMetadata::default();
161
162 while let Some(event) = stream.next().await {
163 match event? {
164 LlmStreamEvent::TextDelta(delta) => text.push_str(&delta),
165 LlmStreamEvent::ThinkingDelta(delta) => thinking.push_str(&delta),
166 LlmStreamEvent::ThinkingSignature(sig) => thinking_signature = Some(sig),
167 LlmStreamEvent::ReasonItem {
168 encrypted_content, ..
169 } => {
170 if let Some(sig) = encrypted_content {
171 thinking_signature = Some(sig);
172 }
173 }
174 LlmStreamEvent::ToolCalls(calls) => tool_calls = calls,
175 LlmStreamEvent::Done(meta) => metadata = *meta,
176 LlmStreamEvent::Error(err) => return Err(crate::error::AgentLoopError::llm(err)),
177 }
178 }
179
180 Ok(LlmResponse {
181 text,
182 thinking: if thinking.is_empty() {
183 None
184 } else {
185 Some(thinking)
186 },
187 thinking_signature,
188 tool_calls: if tool_calls.is_empty() {
189 None
190 } else {
191 Some(tool_calls)
192 },
193 metadata,
194 })
195 }
196
197 async fn list_models(&self) -> Result<Option<Vec<DiscoveredModel>>> {
205 Ok(None)
207 }
208
209 fn supports_compact(&self) -> bool {
218 false
220 }
221
222 async fn compact(&self, _request: CompactRequest) -> Result<Option<CompactResponse>> {
242 Ok(None)
244 }
245}
246
/// Lets a boxed driver be used anywhere an `LlmDriver` is expected.
///
/// Every method forwards to the boxed driver, so methods the concrete driver
/// overrides (including `chat_completion`) are used instead of the trait defaults.
#[async_trait]
impl LlmDriver for Box<dyn LlmDriver> {
    async fn chat_completion_stream(
        &self,
        messages: Vec<LlmMessage>,
        config: &LlmCallConfig,
    ) -> Result<LlmResponseStream> {
        // `self` is `&Box<dyn LlmDriver>`; `**self` is the inner `dyn LlmDriver`.
        // Dispatching through it (rather than through `self`) avoids calling this
        // impl again.
        (**self).chat_completion_stream(messages, config).await
    }

    async fn chat_completion(
        &self,
        messages: Vec<LlmMessage>,
        config: &LlmCallConfig,
    ) -> Result<LlmResponse> {
        // Forwarded explicitly so a driver-specific override is not bypassed by the default.
        (**self).chat_completion(messages, config).await
    }

    async fn list_models(&self) -> Result<Option<Vec<DiscoveredModel>>> {
        (**self).list_models().await
    }

    fn supports_compact(&self) -> bool {
        (**self).supports_compact()
    }

    async fn compact(&self, request: CompactRequest) -> Result<Option<CompactResponse>> {
        (**self).compact(request).await
    }
}
278
/// A single chat message in provider-neutral form, as passed to [`LlmDriver`] methods.
#[derive(Debug, Clone)]
pub struct LlmMessage {
    /// Author role of the message.
    pub role: LlmMessageRole,
    /// Message body: plain text or a list of multimodal parts.
    pub content: LlmMessageContent,
    /// Tool calls attached to the message. The conversions in this file use
    /// `None` rather than an empty list.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Id of the tool call this message answers, if any.
    pub tool_call_id: Option<String>,
    /// Execution phase copied from the source message, if any.
    pub phase: Option<crate::message::ExecutionPhase>,
    /// Thinking/reasoning text carried with the message, if any.
    pub thinking: Option<String>,
    /// Signature paired with `thinking`, if any.
    pub thinking_signature: Option<String>,
}
302
303impl LlmMessage {
304 pub fn text(role: LlmMessageRole, content: impl Into<String>) -> Self {
306 Self {
307 role,
308 content: LlmMessageContent::Text(content.into()),
309 tool_calls: None,
310 tool_call_id: None,
311 phase: None,
312 thinking: None,
313 thinking_signature: None,
314 }
315 }
316
317 pub fn parts(role: LlmMessageRole, parts: Vec<LlmContentPart>) -> Self {
319 Self {
320 role,
321 content: LlmMessageContent::Parts(parts),
322 tool_calls: None,
323 tool_call_id: None,
324 phase: None,
325 thinking: None,
326 thinking_signature: None,
327 }
328 }
329
330 pub fn content_as_text(&self) -> String {
332 self.content.to_text()
333 }
334
335 pub fn prepend_text_prefix(&mut self, prefix: &str) {
340 match &mut self.content {
341 LlmMessageContent::Text(text) => {
342 *text = format!("{}{}", prefix, text);
343 }
344 LlmMessageContent::Parts(parts) => {
345 for part in parts.iter_mut() {
346 if let LlmContentPart::Text { text } = part {
347 *text = format!("{}{}", prefix, text);
348 return;
349 }
350 }
351 parts.insert(
353 0,
354 LlmContentPart::Text {
355 text: prefix.to_string(),
356 },
357 );
358 }
359 }
360 }
361}
362
/// Body of an [`LlmMessage`]: a plain string or an ordered list of parts.
#[derive(Debug, Clone)]
pub enum LlmMessageContent {
    /// Plain text body.
    Text(String),
    /// Ordered multimodal parts (text, image, audio).
    Parts(Vec<LlmContentPart>),
}

impl LlmMessageContent {
    /// Flatten to plain text.
    ///
    /// `Text` is returned as-is. For `Parts`, the text parts are concatenated in
    /// order with no separator, and image/audio parts are skipped.
    pub fn to_text(&self) -> String {
        match self {
            LlmMessageContent::Text(text) => text.clone(),
            LlmMessageContent::Parts(parts) => parts
                .iter()
                .filter_map(|part| match part {
                    LlmContentPart::Text { text } => Some(text.as_str()),
                    LlmContentPart::Image { .. } | LlmContentPart::Audio { .. } => None,
                })
                .collect(),
        }
    }

    /// `true` for the `Text` variant.
    pub fn is_text(&self) -> bool {
        match self {
            LlmMessageContent::Text(_) => true,
            LlmMessageContent::Parts(_) => false,
        }
    }

    /// `true` for the `Parts` variant.
    pub fn is_parts(&self) -> bool {
        match self {
            LlmMessageContent::Parts(_) => true,
            LlmMessageContent::Text(_) => false,
        }
    }
}

impl From<String> for LlmMessageContent {
    /// Wrap an owned string as `Text`.
    fn from(text: String) -> Self {
        Self::Text(text)
    }
}

impl From<&str> for LlmMessageContent {
    /// Copy a string slice into `Text`.
    fn from(text: &str) -> Self {
        Self::Text(text.to_owned())
    }
}

/// One element of multimodal message content.
#[derive(Debug, Clone)]
pub enum LlmContentPart {
    /// Text segment.
    Text { text: String },
    /// Image referenced by URL (the conversions in this file also use `data:` URLs).
    Image { url: String },
    /// Audio referenced by URL.
    Audio { url: String },
}

impl LlmContentPart {
    /// Text part from anything convertible to `String`.
    pub fn text(text: impl Into<String>) -> Self {
        Self::Text { text: text.into() }
    }

    /// Image part pointing at `url`.
    pub fn image(url: impl Into<String>) -> Self {
        Self::Image { url: url.into() }
    }

    /// Audio part pointing at `url`.
    pub fn audio(url: impl Into<String>) -> Self {
        Self::Audio { url: url.into() }
    }
}
438
/// Author role of an [`LlmMessage`].
///
/// The conversions from `crate::message::Message` in this file map `Agent` to
/// `Assistant` and `ToolResult` to `Tool`; `System` and `User` map to themselves.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LlmMessageRole {
    System,
    User,
    Assistant,
    Tool,
}
447
/// Tool-search settings carried on `LlmCallConfig::tool_search`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct ToolSearchConfig {
    /// Whether tool search is turned on.
    pub enabled: bool,
    /// Threshold controlling when tool search applies.
    // NOTE(review): presumably a tool-count cutoff — confirm against the drivers.
    pub threshold: usize,
}

/// How prompt caching is applied.
///
/// Serialized in snake_case; `Auto` (`"auto"`) is the only variant and the default.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum PromptCacheStrategy {
    #[default]
    Auto,
}

/// Prompt-caching settings carried on `LlmCallConfig::prompt_cache`.
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct PromptCacheConfig {
    /// Whether prompt caching is requested.
    pub enabled: bool,
    /// Caching strategy; deserializes to the default (`Auto`) when absent.
    #[serde(default)]
    pub strategy: PromptCacheStrategy,
    /// Handle of an existing Gemini cached-content resource; omitted from the
    /// serialized form when `None`.
    // NOTE(review): presumably only honoured by the Gemini driver — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gemini_cached_content: Option<String>,
}
497
/// OpenRouter-specific request routing: a model fallback list plus provider preferences.
///
/// Every field is omitted from the serialized form when empty or `None`.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterRoutingConfig {
    /// Ordered model list. When non-empty, `validate_for_primary_model` requires
    /// `models[0]` to equal the call's primary model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub models: Vec<String>,
    /// Routing mode; `Fallback` requires a non-empty `models` list.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub route: Option<OpenRouterRoute>,
    /// Provider-selection preferences, serialized as the `provider` object.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider: Option<OpenRouterProviderRouting>,
}
516
517impl OpenRouterRoutingConfig {
518 pub fn is_empty(&self) -> bool {
519 self.models.is_empty() && self.route.is_none() && self.provider.is_none()
520 }
521
522 pub fn fallback_models(models: impl IntoIterator<Item = impl Into<String>>) -> Self {
524 let models = models.into_iter().map(Into::into).collect::<Vec<_>>();
525 let route = (!models.is_empty()).then_some(OpenRouterRoute::Fallback);
526 Self {
527 models,
528 route,
529 provider: None,
530 }
531 }
532
533 pub fn validate_for_primary_model(
534 &self,
535 primary_model: &str,
536 ) -> std::result::Result<(), String> {
537 if self.route == Some(OpenRouterRoute::Fallback) && self.models.is_empty() {
538 return Err(
539 "OpenRouter fallback routing requires at least one model in `models`".to_string(),
540 );
541 }
542
543 if let Some(first_model) = self.models.first()
544 && first_model != primary_model
545 {
546 return Err(format!(
547 "OpenRouter routing models[0] ('{first_model}') must match primary model ('{primary_model}')"
548 ));
549 }
550
551 Ok(())
552 }
553}
554
/// OpenRouter routing mode; serialized in snake_case (`Fallback` → `"fallback"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum OpenRouterRoute {
    Fallback,
}

/// OpenRouter provider-selection preferences, serialized as the request's `provider` object.
///
/// This type only mirrors the JSON shape; field semantics are defined by
/// OpenRouter's provider-routing API. Every field is omitted from the serialized
/// form when empty or `None`.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterProviderRouting {
    /// OpenRouter `order` preference (provider slugs).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub order: Vec<String>,
    /// OpenRouter `only` preference (provider slugs).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub only: Vec<String>,
    /// OpenRouter `ignore` preference (provider slugs).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub ignore: Vec<String>,
    /// OpenRouter `allow_fallbacks` flag.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allow_fallbacks: Option<bool>,
    /// OpenRouter `require_parameters` flag.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub require_parameters: Option<bool>,
    /// OpenRouter `data_collection` policy (`"allow"` / `"deny"`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub data_collection: Option<OpenRouterDataCollection>,
    /// OpenRouter `zdr` flag.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub zdr: Option<bool>,
    /// OpenRouter `enforce_distillable_text` flag.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub enforce_distillable_text: Option<bool>,
    /// OpenRouter `quantizations` filter.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub quantizations: Vec<String>,
    /// OpenRouter `sort` preference: a bare key or an object with `by`/`partition`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sort: Option<OpenRouterProviderSort>,
    /// OpenRouter `max_price` limits.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_price: Option<OpenRouterMaxPrice>,
}
601
602impl OpenRouterProviderRouting {
603 pub fn is_empty(&self) -> bool {
604 self.order.is_empty()
605 && self.only.is_empty()
606 && self.ignore.is_empty()
607 && self.allow_fallbacks.is_none()
608 && self.require_parameters.is_none()
609 && self.data_collection.is_none()
610 && self.zdr.is_none()
611 && self.enforce_distillable_text.is_none()
612 && self.quantizations.is_empty()
613 && self.sort.is_none()
614 && self.max_price.is_none()
615 }
616}
617
/// OpenRouter `data_collection` policy; serialized as `"allow"` / `"deny"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum OpenRouterDataCollection {
    Allow,
    Deny,
}

/// OpenRouter provider `sort` preference.
///
/// `#[serde(untagged)]`: `Simple` serializes as a bare string (e.g. `"price"`),
/// `Advanced` as an object `{ "by": ..., "partition": ... }`. When deserializing,
/// the variants are tried in declaration order.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(untagged)]
pub enum OpenRouterProviderSort {
    Simple(OpenRouterProviderSortBy),
    Advanced(OpenRouterProviderSortOptions),
}

/// Sort key for provider selection; serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum OpenRouterProviderSortBy {
    Price,
    Throughput,
    Latency,
}

/// Object form of the OpenRouter `sort` preference.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterProviderSortOptions {
    /// Key to sort providers by.
    pub by: OpenRouterProviderSortBy,
    /// Optional partitioning; omitted from the serialized form when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub partition: Option<OpenRouterSortPartition>,
}

/// OpenRouter sort partition; serialized in snake_case.
///
/// Note that `None` here is a real variant (serialized as `"none"`), distinct
/// from an absent `partition` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum OpenRouterSortPartition {
    Model,
    None,
}

/// OpenRouter `max_price` limits; each unset limit is omitted from the serialized form.
// NOTE(review): price units are defined by OpenRouter's API — confirm before relying on them.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterMaxPrice {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt: Option<f64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub completion: Option<f64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub request: Option<f64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub image: Option<f64>,
}
678
/// Per-call settings handed to [`LlmDriver`] methods.
///
/// Usually seeded from a `RuntimeAgent` via `From<&RuntimeAgent>` and then
/// adjusted with [`LlmCallConfigBuilder`].
#[derive(Debug, Clone)]
pub struct LlmCallConfig {
    /// Model identifier to request.
    pub model: String,
    /// Sampling temperature, if set.
    pub temperature: Option<f32>,
    /// Output token limit, if set.
    pub max_tokens: Option<u32>,
    /// Tools offered to the model.
    pub tools: Vec<ToolDefinition>,
    /// Reasoning-effort hint (free-form string such as `"high"`), if set.
    pub reasoning_effort: Option<String>,
    /// Arbitrary key/value metadata attached to the call.
    pub metadata: HashMap<String, String>,
    /// Id of a previous provider response to continue from, if any.
    pub previous_response_id: Option<String>,
    /// Tool-search settings, if any.
    pub tool_search: Option<ToolSearchConfig>,
    /// Prompt-caching settings, if any.
    pub prompt_cache: Option<PromptCacheConfig>,
    /// OpenRouter routing settings; the builder stores `None` instead of an empty config.
    pub openrouter_routing: Option<OpenRouterRoutingConfig>,
}
702
703impl From<&RuntimeAgent> for LlmCallConfig {
704 fn from(runtime_agent: &RuntimeAgent) -> Self {
705 Self {
706 model: runtime_agent.model.clone(),
707 temperature: runtime_agent.temperature,
708 max_tokens: runtime_agent.max_tokens,
709 tools: runtime_agent.tools.clone(),
710 reasoning_effort: None, metadata: HashMap::new(), previous_response_id: None,
713 tool_search: runtime_agent.tool_search.clone(),
714 prompt_cache: runtime_agent.prompt_cache.clone(),
715 openrouter_routing: None,
716 }
717 }
718}
719
/// Aggregated result of a completed (non-streaming) LLM call.
///
/// When produced by the default [`LlmDriver::chat_completion`], `thinking` and
/// `tool_calls` are `None` if nothing non-empty was received.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Concatenated response text (may be empty).
    pub text: String,
    /// Concatenated thinking text, if any.
    pub thinking: Option<String>,
    /// Last thinking signature / encrypted reasoning content received, if any.
    pub thinking_signature: Option<String>,
    /// Tool calls requested by the model, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Usage and completion metadata.
    pub metadata: LlmCompletionMetadata,
}
731
/// Fluent builder for [`LlmCallConfig`].
///
/// Start with `LlmCallConfigBuilder::from(&runtime_agent)`, chain setters, then
/// call `build()`.
pub struct LlmCallConfigBuilder {
    config: LlmCallConfig,
}
753
754impl LlmCallConfigBuilder {
755 pub fn from(runtime_agent: &RuntimeAgent) -> Self {
757 Self {
758 config: LlmCallConfig::from(runtime_agent),
759 }
760 }
761
762 pub fn reasoning_effort(mut self, effort: impl Into<String>) -> Self {
764 self.config.reasoning_effort = Some(effort.into());
765 self
766 }
767
768 pub fn model(mut self, model: impl Into<String>) -> Self {
770 self.config.model = model.into();
771 self
772 }
773
774 pub fn temperature(mut self, temp: f32) -> Self {
776 self.config.temperature = Some(temp);
777 self
778 }
779
780 pub fn max_tokens(mut self, tokens: u32) -> Self {
782 self.config.max_tokens = Some(tokens);
783 self
784 }
785
786 pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
788 self.config.tools = tools;
789 self
790 }
791
792 pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
797 self.config.metadata = metadata;
798 self
799 }
800
801 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
803 self.config.metadata.insert(key.into(), value.into());
804 self
805 }
806
807 pub fn previous_response_id(mut self, id: Option<String>) -> Self {
809 self.config.previous_response_id = id;
810 self
811 }
812
813 pub fn tool_search(mut self, config: ToolSearchConfig) -> Self {
815 self.config.tool_search = Some(config);
816 self
817 }
818
819 pub fn prompt_cache(mut self, config: PromptCacheConfig) -> Self {
821 self.config.prompt_cache = Some(config);
822 self
823 }
824
825 pub fn openrouter_routing(mut self, config: OpenRouterRoutingConfig) -> Self {
827 self.config.openrouter_routing = (!config.is_empty()).then_some(config);
828 self
829 }
830
831 pub fn build(self) -> LlmCallConfig {
833 self.config
834 }
835}
836
837impl From<&crate::message::Message> for LlmMessage {
842 fn from(msg: &crate::message::Message) -> Self {
848 let role = match msg.role {
849 crate::message::MessageRole::System => LlmMessageRole::System,
850 crate::message::MessageRole::User => LlmMessageRole::User,
851 crate::message::MessageRole::Agent => LlmMessageRole::Assistant,
852 crate::message::MessageRole::ToolResult => LlmMessageRole::Tool,
853 };
854
855 let tool_calls: Vec<ToolCall> = msg
857 .tool_calls()
858 .into_iter()
859 .map(|tc| ToolCall {
860 id: tc.id.clone(),
861 name: tc.name.clone(),
862 arguments: tc.arguments.clone(),
863 })
864 .collect();
865
866 LlmMessage {
867 role,
868 content: LlmMessageContent::Text(msg.content_to_llm_string()),
869 tool_calls: if tool_calls.is_empty() {
870 None
871 } else {
872 Some(tool_calls)
873 },
874 tool_call_id: msg.tool_call_id().map(|s| s.to_string()),
875 phase: msg.phase,
876 thinking: msg.thinking.clone(),
877 thinking_signature: msg.thinking_signature.clone(),
878 }
879 }
880}
881
882use crate::traits::ResolvedImage;
887use uuid::Uuid;
888
889impl LlmMessage {
890 pub fn from_message_with_images(
910 msg: &crate::message::Message,
911 resolved_images: &HashMap<Uuid, ResolvedImage>,
912 ) -> Self {
913 use crate::message::{ContentPart, MessageRole};
914
915 let role = match msg.role {
916 MessageRole::System => LlmMessageRole::System,
917 MessageRole::User => LlmMessageRole::User,
918 MessageRole::Agent => LlmMessageRole::Assistant,
919 MessageRole::ToolResult => LlmMessageRole::Tool,
920 };
921
922 let mut parts: Vec<LlmContentPart> = Vec::new();
924 let mut tool_calls: Vec<ToolCall> = Vec::new();
925
926 for part in &msg.content {
927 match part {
928 ContentPart::Text(t) => {
929 parts.push(LlmContentPart::Text {
930 text: t.text.clone(),
931 });
932 }
933 ContentPart::Image(img) => {
934 if let Some(url) = &img.url {
936 parts.push(LlmContentPart::Image { url: url.clone() });
937 } else if let (Some(base64), Some(media_type)) = (&img.base64, &img.media_type)
938 {
939 let data_url = format!("data:{};base64,{}", media_type, base64);
940 parts.push(LlmContentPart::Image { url: data_url });
941 }
942 }
943 ContentPart::ImageFile(img_file) => {
944 if let Some(resolved) = resolved_images.get(&img_file.image_id.uuid()) {
946 parts.push(LlmContentPart::Image {
947 url: resolved.to_data_url(),
948 });
949 } else {
950 parts.push(LlmContentPart::Text {
952 text: format!("[Image not found: {}]", img_file.image_id),
953 });
954 }
955 }
956 ContentPart::ToolCall(tc) => {
957 tool_calls.push(ToolCall {
959 id: tc.id.clone(),
960 name: tc.name.clone(),
961 arguments: tc.arguments.clone(),
962 });
963 }
964 ContentPart::ToolResult(tr) => {
965 let text = if let Some(err) = &tr.error {
967 format!("Tool error: {}", err)
968 } else if let Some(res) = &tr.result {
969 serde_json::to_string(res).unwrap_or_else(|_| "{}".to_string())
970 } else {
971 "{}".to_string()
972 };
973 let text = truncate_tool_result(text);
977 parts.push(LlmContentPart::Text { text });
978 }
979 }
980 }
981
982 let content = if parts.len() == 1 && matches!(&parts[0], LlmContentPart::Text { .. }) {
984 if let LlmContentPart::Text { text } = &parts[0] {
986 LlmMessageContent::Text(text.clone())
987 } else {
988 LlmMessageContent::Parts(parts)
989 }
990 } else if parts.is_empty() {
991 LlmMessageContent::Text(String::new())
993 } else {
994 LlmMessageContent::Parts(parts)
996 };
997
998 LlmMessage {
999 role,
1000 content,
1001 tool_calls: if tool_calls.is_empty() {
1002 None
1003 } else {
1004 Some(tool_calls)
1005 },
1006 tool_call_id: msg.tool_call_id().map(|s| s.to_string()),
1007 phase: msg.phase,
1008 thinking: msg.thinking.clone(),
1009 thinking_signature: msg.thinking_signature.clone(),
1010 }
1011 }
1012
1013 pub fn message_has_image_files(msg: &crate::message::Message) -> bool {
1015 msg.content.iter().any(|p| p.is_image_file())
1016 }
1017
1018 pub fn extract_image_file_ids(msg: &crate::message::Message) -> Vec<Uuid> {
1020 msg.content
1021 .iter()
1022 .filter_map(|p| match p {
1023 crate::message::ContentPart::ImageFile(f) => Some(f.image_id.uuid()),
1024 _ => None,
1025 })
1026 .collect()
1027 }
1028}
1029
/// LLM provider backends a driver can be registered for.
///
/// Round-trips through its lowercase snake_case name (`"openai"`,
/// `"azure_openai"`, ...). Parsing lowercases the input first, so it is
/// case-insensitive.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ProviderType {
    OpenAI,
    OpenRouter,
    AzureOpenAI,
    OpenAICompletions,
    Anthropic,
    Gemini,
    /// Simulated backend; `DriverRegistry::create_driver` does not require an API key for it.
    LlmSim,
    Bedrock,
}

impl std::str::FromStr for ProviderType {
    type Err = String;

    /// Parse a provider name (case-insensitive).
    ///
    /// # Errors
    /// `"Unknown provider type: <input>"`, echoing the original input.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let provider = match normalized.as_str() {
            "openai" => Self::OpenAI,
            "openrouter" => Self::OpenRouter,
            "azure_openai" => Self::AzureOpenAI,
            "openai_completions" => Self::OpenAICompletions,
            "anthropic" => Self::Anthropic,
            "gemini" => Self::Gemini,
            "llmsim" => Self::LlmSim,
            "bedrock" => Self::Bedrock,
            _ => return Err(format!("Unknown provider type: {}", s)),
        };
        Ok(provider)
    }
}

impl std::fmt::Display for ProviderType {
    /// Writes the canonical lowercase name accepted by `FromStr`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::OpenAI => "openai",
            Self::OpenRouter => "openrouter",
            Self::AzureOpenAI => "azure_openai",
            Self::OpenAICompletions => "openai_completions",
            Self::Anthropic => "anthropic",
            Self::Gemini => "gemini",
            Self::LlmSim => "llmsim",
            Self::Bedrock => "bedrock",
        };
        f.write_str(name)
    }
}
1088
/// Connection settings consumed by `DriverRegistry::create_driver`.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Selects which registered factory builds the driver.
    pub provider_type: ProviderType,
    /// API key passed to the factory; required for every provider except `LlmSim`.
    pub api_key: Option<String>,
    /// Optional base URL passed through to the factory.
    pub base_url: Option<String>,
}
1099
1100impl ProviderConfig {
1101 pub fn new(provider_type: ProviderType) -> Self {
1103 Self {
1104 provider_type,
1105 api_key: None,
1106 base_url: None,
1107 }
1108 }
1109
1110 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
1112 self.api_key = Some(api_key.into());
1113 self
1114 }
1115
1116 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1118 self.base_url = Some(base_url.into());
1119 self
1120 }
1121}
1122
1123impl From<crate::llm_models::LlmProviderType> for ProviderType {
1124 fn from(provider_type: crate::llm_models::LlmProviderType) -> Self {
1125 use crate::llm_models::LlmProviderType;
1126 match provider_type {
1127 LlmProviderType::Openai => ProviderType::OpenAI,
1128 LlmProviderType::Openrouter => ProviderType::OpenRouter,
1129 LlmProviderType::AzureOpenai => ProviderType::AzureOpenAI,
1130 LlmProviderType::OpenaiCompletions => ProviderType::OpenAICompletions,
1131 LlmProviderType::Anthropic => ProviderType::Anthropic,
1132 LlmProviderType::Gemini => ProviderType::Gemini,
1133 LlmProviderType::LlmSim => ProviderType::LlmSim,
1134 LlmProviderType::Bedrock => ProviderType::Bedrock,
1135 }
1136 }
1137}
1138
1139impl From<&crate::traits::ModelWithProvider> for ProviderConfig {
1140 fn from(model: &crate::traits::ModelWithProvider) -> Self {
1141 Self {
1142 provider_type: model.provider_type.clone().into(),
1143 api_key: model.api_key.clone(),
1144 base_url: model.base_url.clone(),
1145 }
1146 }
1147}
1148
/// Heap-allocated, dynamically dispatched driver, as returned by a [`DriverFactory`].
pub type BoxedLlmDriver = Box<dyn LlmDriver>;

/// Shared driver constructor.
///
/// `DriverRegistry::create_driver` calls it with the API key (an empty string
/// for `LlmSim` when no key is configured) and the optional base URL.
pub type DriverFactory = Arc<dyn Fn(&str, Option<&str>) -> BoxedLlmDriver + Send + Sync>;

/// Maps each [`ProviderType`] to the factory that builds its driver.
///
/// Cloning copies the map; the factories themselves are `Arc`-shared.
#[derive(Clone, Default)]
pub struct DriverRegistry {
    factories: HashMap<ProviderType, DriverFactory>,
}
1184
1185impl DriverRegistry {
1186 pub fn new() -> Self {
1188 Self {
1189 factories: HashMap::new(),
1190 }
1191 }
1192
1193 pub fn register<F>(&mut self, provider_type: ProviderType, factory: F)
1195 where
1196 F: Fn(&str, Option<&str>) -> BoxedLlmDriver + Send + Sync + 'static,
1197 {
1198 self.factories.insert(provider_type, Arc::new(factory));
1199 }
1200
1201 pub fn create_driver(&self, config: &ProviderConfig) -> Result<BoxedLlmDriver> {
1209 let api_key = if config.provider_type == ProviderType::LlmSim {
1211 config.api_key.as_deref().unwrap_or("")
1213 } else {
1214 config.api_key.as_ref().ok_or_else(|| {
1215 AgentLoopError::llm(
1216 "API key is required. Configure the API key in provider settings.",
1217 )
1218 })?
1219 };
1220
1221 let factory = self.factories.get(&config.provider_type).ok_or_else(|| {
1223 AgentLoopError::driver_not_registered(config.provider_type.to_string())
1224 })?;
1225
1226 Ok(factory(api_key, config.base_url.as_deref()))
1228 }
1229
1230 pub fn has_driver(&self, provider_type: &ProviderType) -> bool {
1232 self.factories.contains_key(provider_type)
1233 }
1234
1235 pub fn registered_providers(&self) -> Vec<ProviderType> {
1237 self.factories.keys().cloned().collect()
1238 }
1239}
1240
/// Upper bound, in bytes, on a tool-result string after truncation (suffix included).
const MAX_TOOL_RESULT_BYTES: usize = 64 * 1024;

/// Marker appended to tool results that were cut down by [`truncate_tool_result`].
const TRUNCATION_SUFFIX: &str =
    "\n\n[Output truncated — exceeded 64 KiB limit. Try quiet flags, pipes, or redirect to file.]";

/// Cap `text` at [`MAX_TOOL_RESULT_BYTES`] bytes.
///
/// Text within the limit is returned unchanged. Longer text is cut at the last
/// UTF-8 character boundary that leaves room for [`TRUNCATION_SUFFIX`], and the
/// suffix is appended. The result never exceeds the limit.
fn truncate_tool_result(mut text: String) -> String {
    if text.len() > MAX_TOOL_RESULT_BYTES {
        let budget = MAX_TOOL_RESULT_BYTES.saturating_sub(TRUNCATION_SUFFIX.len());
        // Walk back to a char boundary so the cut never splits a multi-byte character;
        // offset 0 is always a boundary.
        let cut = (0..=budget)
            .rev()
            .find(|&offset| text.is_char_boundary(offset))
            .unwrap_or(0);
        text.truncate(cut);
        text.push_str(TRUNCATION_SUFFIX);
    }
    text
}
1263
1264#[cfg(test)]
1269mod tests {
1270 use super::*;
1271
1272 #[test]
1273 fn test_llm_call_config_builder_from_runtime_agent() {
1274 let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
1275 let llm_config = LlmCallConfigBuilder::from(&runtime_agent).build();
1276
1277 assert_eq!(llm_config.model, "gpt-4o");
1278 assert!(llm_config.reasoning_effort.is_none());
1279 assert!(llm_config.temperature.is_none());
1280 assert!(llm_config.max_tokens.is_none());
1281 assert!(llm_config.tools.is_empty());
1282 assert!(llm_config.metadata.is_empty());
1283 }
1284
1285 #[test]
1286 fn test_llm_call_config_builder_with_metadata() {
1287 let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
1288 let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
1289 .with_metadata("session_id", "session_abc123")
1290 .with_metadata("agent_id", "agent_xyz789")
1291 .build();
1292
1293 assert_eq!(
1294 llm_config.metadata.get("session_id"),
1295 Some(&"session_abc123".to_string())
1296 );
1297 assert_eq!(
1298 llm_config.metadata.get("agent_id"),
1299 Some(&"agent_xyz789".to_string())
1300 );
1301 }
1302
1303 #[test]
1304 fn test_llm_call_config_builder_with_metadata_hashmap() {
1305 let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
1306 let mut metadata = HashMap::new();
1307 metadata.insert("key1".to_string(), "value1".to_string());
1308 metadata.insert("key2".to_string(), "value2".to_string());
1309
1310 let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
1311 .metadata(metadata)
1312 .build();
1313
1314 assert_eq!(llm_config.metadata.get("key1"), Some(&"value1".to_string()));
1315 assert_eq!(llm_config.metadata.get("key2"), Some(&"value2".to_string()));
1316 }
1317
1318 #[test]
1319 fn test_llm_call_config_builder_with_reasoning_effort() {
1320 let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
1321 let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
1322 .reasoning_effort("high")
1323 .build();
1324
1325 assert_eq!(llm_config.reasoning_effort, Some("high".to_string()));
1326 }
1327
1328 #[test]
1329 fn test_llm_call_config_builder_with_all_options() {
1330 let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
1331 let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
1332 .model("claude-3-opus")
1333 .reasoning_effort("medium")
1334 .temperature(0.7)
1335 .max_tokens(1000)
1336 .build();
1337
1338 assert_eq!(llm_config.model, "claude-3-opus");
1339 assert_eq!(llm_config.reasoning_effort, Some("medium".to_string()));
1340 assert_eq!(llm_config.temperature, Some(0.7));
1341 assert_eq!(llm_config.max_tokens, Some(1000));
1342 }
1343
1344 #[test]
1345 fn test_llm_call_config_builder_with_openrouter_routing() {
1346 let runtime_agent = RuntimeAgent::new("You are helpful", "openai/gpt-5-mini");
1347 let routing = OpenRouterRoutingConfig::fallback_models([
1348 "openai/gpt-5-mini",
1349 "anthropic/claude-sonnet-4.5",
1350 ]);
1351
1352 let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
1353 .openrouter_routing(routing.clone())
1354 .build();
1355
1356 assert_eq!(llm_config.openrouter_routing, Some(routing));
1357 }
1358
1359 #[test]
1360 fn test_openrouter_fallback_models_empty_is_empty() {
1361 let routing = OpenRouterRoutingConfig::fallback_models(std::iter::empty::<String>());
1362
1363 assert!(routing.is_empty());
1364 assert_eq!(routing.route, None);
1365 }
1366
1367 #[test]
1368 fn test_openrouter_routing_validates_primary_model() {
1369 let routing = OpenRouterRoutingConfig::fallback_models([
1370 "openai/gpt-5-mini",
1371 "anthropic/claude-sonnet-4.5",
1372 ]);
1373
1374 assert!(
1375 routing
1376 .validate_for_primary_model("openai/gpt-5-mini")
1377 .is_ok()
1378 );
1379 let err = routing
1380 .validate_for_primary_model("anthropic/claude-sonnet-4.5")
1381 .unwrap_err();
1382 assert!(err.contains("models[0]"));
1383 }
1384
1385 #[test]
1386 fn test_openrouter_routing_rejects_fallback_without_models() {
1387 let routing = OpenRouterRoutingConfig {
1388 route: Some(OpenRouterRoute::Fallback),
1389 ..Default::default()
1390 };
1391
1392 let err = routing
1393 .validate_for_primary_model("openai/gpt-5-mini")
1394 .unwrap_err();
1395 assert!(err.contains("requires at least one model"));
1396 }
1397
/// Serializing a fully populated routing config must produce exactly the OpenRouter
/// request JSON below. Fields left at their `Default` values must not appear, because
/// the whole `Value` is compared.
#[test]
fn test_openrouter_routing_serializes_request_fields() {
    // Build the nested pieces bottom-up so each layer is easy to read.
    let price_cap = OpenRouterMaxPrice {
        prompt: Some(1.0),
        completion: Some(2.0),
        ..Default::default()
    };
    let throughput_sort = OpenRouterProviderSort::Advanced(OpenRouterProviderSortOptions {
        by: OpenRouterProviderSortBy::Throughput,
        partition: Some(OpenRouterSortPartition::None),
    });
    let provider_prefs = OpenRouterProviderRouting {
        order: vec!["anthropic".to_string(), "openai".to_string()],
        allow_fallbacks: Some(false),
        require_parameters: Some(true),
        data_collection: Some(OpenRouterDataCollection::Deny),
        zdr: Some(true),
        sort: Some(throughput_sort),
        max_price: Some(price_cap),
        ..Default::default()
    };
    let routing = OpenRouterRoutingConfig {
        models: vec![
            "openai/gpt-5-mini".to_string(),
            "anthropic/claude-sonnet-4.5".to_string(),
        ],
        route: Some(OpenRouterRoute::Fallback),
        provider: Some(provider_prefs),
    };

    let actual = serde_json::to_value(routing).unwrap();
    let expected = serde_json::json!({
        "models": [
            "openai/gpt-5-mini",
            "anthropic/claude-sonnet-4.5"
        ],
        "route": "fallback",
        "provider": {
            "order": ["anthropic", "openai"],
            "allow_fallbacks": false,
            "require_parameters": true,
            "data_collection": "deny",
            "zdr": true,
            "sort": {
                "by": "throughput",
                "partition": "none"
            },
            "max_price": {
                "prompt": 1.0,
                "completion": 2.0
            }
        }
    });

    assert_eq!(actual, expected);
}
1455
/// Every supported provider identifier parses to its `ProviderType` variant.
/// Identifiers without a variant (e.g. `ollama`, `custom`) are rejected.
#[test]
fn test_provider_type_parsing() {
    let supported = [
        ("openai", ProviderType::OpenAI),
        ("openrouter", ProviderType::OpenRouter),
        ("openai_completions", ProviderType::OpenAICompletions),
        ("azure_openai", ProviderType::AzureOpenAI),
        ("anthropic", ProviderType::Anthropic),
        ("gemini", ProviderType::Gemini),
    ];
    for (raw, expected) in supported {
        assert_eq!(raw.parse::<ProviderType>().unwrap(), expected);
    }

    for unsupported in ["ollama", "custom"] {
        assert!(unsupported.parse::<ProviderType>().is_err());
    }
}
1486
/// `Display` renders each `ProviderType` as its lowercase wire identifier.
#[test]
fn test_provider_type_display() {
    let expectations = [
        (ProviderType::OpenAI, "openai"),
        (ProviderType::OpenRouter, "openrouter"),
        (ProviderType::AzureOpenAI, "azure_openai"),
        (ProviderType::OpenAICompletions, "openai_completions"),
        (ProviderType::Anthropic, "anthropic"),
        (ProviderType::Gemini, "gemini"),
    ];
    for (provider, rendered) in expectations {
        assert_eq!(provider.to_string(), rendered);
    }
}
1499
/// The `with_*` builder methods store the API key and base URL on the config,
/// without touching the provider type passed to `new`.
#[test]
fn test_provider_config_builder() {
    let built = ProviderConfig::new(ProviderType::Anthropic)
        .with_api_key("test-key")
        .with_base_url("https://custom.api.com");

    assert_eq!(built.provider_type, ProviderType::Anthropic);
    assert_eq!(built.api_key.as_deref(), Some("test-key"));
    assert_eq!(built.base_url.as_deref(), Some("https://custom.api.com"));
}
1510
/// For a registered provider, `create_driver` fails when the config has no API key
/// and succeeds once one is supplied.
#[test]
fn test_driver_registry_requires_api_key() {
    // Stub driver. Streaming is left unimplemented because this test only
    // exercises driver construction.
    struct MockDriver;

    #[async_trait]
    impl LlmDriver for MockDriver {
        async fn chat_completion_stream(
            &self,
            _messages: Vec<LlmMessage>,
            _config: &LlmCallConfig,
        ) -> Result<LlmResponseStream> {
            unimplemented!()
        }
    }

    let mut registry = DriverRegistry::new();
    registry.register(ProviderType::OpenAI, |_api_key, _base_url| {
        Box::new(MockDriver)
    });

    let keyless = ProviderConfig::new(ProviderType::OpenAI);
    assert!(registry.create_driver(&keyless).is_err());

    let keyed = ProviderConfig::new(ProviderType::OpenAI).with_api_key("test-key");
    assert!(registry.create_driver(&keyed).is_ok());
}
1541
/// Asking an empty registry for a driver yields `DriverNotRegistered`, carrying the
/// provider's identifier.
#[test]
fn test_driver_registry_returns_error_for_unregistered_provider() {
    let empty_registry = DriverRegistry::new();
    let anthropic = ProviderConfig::new(ProviderType::Anthropic).with_api_key("test-key");

    match empty_registry.create_driver(&anthropic) {
        Err(AgentLoopError::DriverNotRegistered(provider)) => {
            assert_eq!(provider, "anthropic");
        }
        _ => panic!("Expected DriverNotRegistered error"),
    }
}
1556
/// `has_driver` reports only providers that have been registered. Registering
/// OpenAI must not make Anthropic appear registered.
#[test]
fn test_driver_registry_registration() {
    // Stub driver. `chat_completion_stream` is never called by this test.
    struct MockDriver;

    #[async_trait]
    impl LlmDriver for MockDriver {
        async fn chat_completion_stream(
            &self,
            _messages: Vec<LlmMessage>,
            _config: &LlmCallConfig,
        ) -> Result<LlmResponseStream> {
            unimplemented!()
        }
    }

    let mut registry = DriverRegistry::new();
    for provider in [ProviderType::OpenAI, ProviderType::Anthropic] {
        assert!(!registry.has_driver(&provider));
    }

    registry.register(ProviderType::OpenAI, |_, _| Box::new(MockDriver));

    assert!(registry.has_driver(&ProviderType::OpenAI));
    assert!(!registry.has_driver(&ProviderType::Anthropic));
}
1582
1583 use crate::{ContentPart, ImageFileContentPart, Message, MessageRole, TextContentPart};
1588
/// A message containing at least one `ImageFile` part is detected as having image files.
#[test]
fn test_message_has_image_files_with_image_file() {
    let content = vec![
        ContentPart::Text(TextContentPart {
            text: "Look at this image".to_string(),
        }),
        ContentPart::ImageFile(ImageFileContentPart {
            image_id: uuid::Uuid::new_v4().into(),
            filename: Some("test.png".to_string()),
        }),
    ];
    let message = Message {
        id: uuid::Uuid::new_v4().into(),
        role: MessageRole::User,
        content,
        phase: None,
        thinking: None,
        thinking_signature: None,
        controls: None,
        metadata: None,
        external_actor: None,
        created_at: chrono::Utc::now(),
    };

    let detected = LlmMessage::message_has_image_files(&message);
    assert!(detected);
}
1614
/// A text-only message is not reported as having image files.
#[test]
fn test_message_has_image_files_without_image_file() {
    let text_only = vec![ContentPart::Text(TextContentPart {
        text: "Just text".to_string(),
    })];
    let message = Message {
        id: uuid::Uuid::new_v4().into(),
        role: MessageRole::User,
        content: text_only,
        phase: None,
        thinking: None,
        thinking_signature: None,
        controls: None,
        metadata: None,
        external_actor: None,
        created_at: chrono::Utc::now(),
    };

    let detected = LlmMessage::message_has_image_files(&message);
    assert!(!detected);
}
1634
/// `extract_image_file_ids` returns the id of every `ImageFile` part and skips text parts.
#[test]
fn test_extract_image_file_ids() {
    let (first_id, second_id) = (uuid::Uuid::new_v4(), uuid::Uuid::new_v4());

    // Small local builder so the two image parts differ only in id and filename.
    let image_part = |id: uuid::Uuid, name: &str| {
        ContentPart::ImageFile(ImageFileContentPart {
            image_id: id.into(),
            filename: Some(name.to_string()),
        })
    };

    let message = Message {
        id: uuid::Uuid::new_v4().into(),
        role: MessageRole::User,
        content: vec![
            ContentPart::Text(TextContentPart {
                text: "Look at these images".to_string(),
            }),
            image_part(first_id, "test1.png"),
            image_part(second_id, "test2.png"),
        ],
        phase: None,
        thinking: None,
        thinking_signature: None,
        controls: None,
        metadata: None,
        external_actor: None,
        created_at: chrono::Utc::now(),
    };

    let ids = LlmMessage::extract_image_file_ids(&message);
    assert_eq!(ids.len(), 2);
    for expected in [first_id, second_id] {
        assert!(ids.contains(&expected));
    }
}
1670
/// With no images involved, conversion keeps the role and yields plain `Text` content.
#[test]
fn test_from_message_with_images_text_only() {
    let message = Message {
        id: uuid::Uuid::new_v4().into(),
        role: MessageRole::User,
        content: vec![ContentPart::Text(TextContentPart {
            text: "Hello".to_string(),
        })],
        phase: None,
        thinking: None,
        thinking_signature: None,
        controls: None,
        metadata: None,
        external_actor: None,
        created_at: chrono::Utc::now(),
    };

    let no_images = std::collections::HashMap::new();
    let converted = LlmMessage::from_message_with_images(&message, &no_images);

    assert_eq!(converted.role, LlmMessageRole::User);
    if let LlmMessageContent::Text(text) = converted.content {
        assert_eq!(text, "Hello");
    } else {
        panic!("Expected text content");
    }
}
1697
/// A resolved image becomes a multipart message: the text part followed by an
/// image part whose URL is a `data:` URI built from the resolved MIME type.
#[test]
fn test_from_message_with_images_resolved_image() {
    let image_id = uuid::Uuid::new_v4();
    let message = Message {
        id: uuid::Uuid::new_v4().into(),
        role: MessageRole::User,
        content: vec![
            ContentPart::Text(TextContentPart {
                text: "Look at this".to_string(),
            }),
            ContentPart::ImageFile(ImageFileContentPart {
                image_id: image_id.into(),
                filename: Some("test.png".to_string()),
            }),
        ],
        phase: None,
        thinking: None,
        thinking_signature: None,
        controls: None,
        metadata: None,
        external_actor: None,
        created_at: chrono::Utc::now(),
    };

    let resolved = std::collections::HashMap::from([(
        image_id,
        crate::ResolvedImage::new("base64data", "image/png"),
    )]);

    let llm_message = LlmMessage::from_message_with_images(&message, &resolved);

    let parts = match &llm_message.content {
        LlmMessageContent::Parts(parts) => parts,
        _ => panic!("Expected parts content"),
    };
    assert_eq!(parts.len(), 2);
    assert!(matches!(&parts[0], LlmContentPart::Text { .. }));
    match &parts[1] {
        LlmContentPart::Image { url } => assert!(url.starts_with("data:image/png;base64,")),
        _ => panic!("Expected image content part"),
    }
}
1745
/// An image id missing from the resolved map is replaced by a text placeholder
/// containing "Image not found". The placeholder may arrive as plain `Text`
/// content or as a single text part.
#[test]
fn test_from_message_with_images_unresolved_image() {
    let image_id = uuid::Uuid::new_v4();
    let message = Message {
        id: uuid::Uuid::new_v4().into(),
        role: MessageRole::User,
        content: vec![ContentPart::ImageFile(ImageFileContentPart {
            image_id: image_id.into(),
            filename: Some("missing.png".to_string()),
        })],
        phase: None,
        thinking: None,
        thinking_signature: None,
        controls: None,
        metadata: None,
        external_actor: None,
        created_at: chrono::Utc::now(),
    };

    let resolved = std::collections::HashMap::new();
    let llm_message = LlmMessage::from_message_with_images(&message, &resolved);

    // Pull out the placeholder text whichever content shape was produced.
    let placeholder: &str = match &llm_message.content {
        LlmMessageContent::Text(text) => text,
        LlmMessageContent::Parts(parts) => {
            assert_eq!(parts.len(), 1);
            match &parts[0] {
                LlmContentPart::Text { text } => text,
                _ => panic!("Expected text placeholder for missing image"),
            }
        }
    };
    assert!(placeholder.contains("Image not found"));
}
1785
/// On a plain-text message, the prefix is prepended directly to the text.
#[test]
fn test_prepend_text_prefix_simple_text() {
    let mut greeting = LlmMessage::text(LlmMessageRole::User, "Hello bot");
    greeting.prepend_text_prefix("[Alice] ");

    let rendered = greeting.content_as_text();
    assert_eq!(rendered, "[Alice] Hello bot");
}
1792
/// On a multipart message whose first part is text, the prefix is merged into that
/// existing text part. No extra part is inserted, and the trailing image part is
/// left unchanged.
#[test]
fn test_prepend_text_prefix_parts() {
    let image_url = "data:image/png;base64,abc";
    let mut msg = LlmMessage::parts(
        LlmMessageRole::User,
        vec![
            LlmContentPart::Text {
                text: "Hello".to_string(),
            },
            LlmContentPart::Image {
                url: image_url.to_string(),
            },
        ],
    );
    msg.prepend_text_prefix("[Bob] ");
    match &msg.content {
        LlmMessageContent::Parts(parts) => {
            // The existing text part should be reused rather than a new one inserted.
            assert_eq!(parts.len(), 2);
            if let LlmContentPart::Text { text } = &parts[0] {
                assert_eq!(text, "[Bob] Hello");
            } else {
                panic!("Expected text part");
            }
            // Non-text parts must survive the prefixing untouched.
            if let LlmContentPart::Image { url } = &parts[1] {
                assert_eq!(url, image_url);
            } else {
                panic!("Expected image part to be preserved");
            }
        }
        _ => panic!("Expected parts content"),
    }
}
1818
/// On a multipart message with no text part, a new text part holding only the
/// prefix is inserted at the front, and the original image part follows it unchanged.
#[test]
fn test_prepend_text_prefix_parts_no_text() {
    let image_url = "data:image/png;base64,abc";
    let mut msg = LlmMessage::parts(
        LlmMessageRole::User,
        vec![LlmContentPart::Image {
            url: image_url.to_string(),
        }],
    );
    msg.prepend_text_prefix("[Eve] ");
    match &msg.content {
        LlmMessageContent::Parts(parts) => {
            assert_eq!(parts.len(), 2);
            if let LlmContentPart::Text { text } = &parts[0] {
                assert_eq!(text, "[Eve] ");
            } else {
                panic!("Expected prepended text part");
            }
            // The pre-existing image must be shifted back, not replaced or altered.
            if let LlmContentPart::Image { url } = &parts[1] {
                assert_eq!(url, image_url);
            } else {
                panic!("Expected original image part after the prefix");
            }
        }
        _ => panic!("Expected parts content"),
    }
}
1840}