1use crate::credential_schema::CredentialFormSchema;
18use crate::error::{AgentLoopError, Result};
19use crate::openresponses_protocol::{CompactRequest, CompactResponse};
20use crate::runtime_agent::RuntimeAgent;
21use crate::tool_types::{ToolCall, ToolDefinition};
22use async_trait::async_trait;
23use chrono::{DateTime, Utc};
24use futures::Stream;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::pin::Pin;
28use std::sync::Arc;
29
/// Pinned, boxed, `Send` stream of [`LlmStreamEvent`]s returned by
/// [`ChatDriver::chat_completion_stream`].
///
/// Each item is a `Result`, so an error can surface on any individual item
/// while the stream is being consumed.
pub type LlmResponseStream = Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>;
36
/// One incremental event emitted by a streaming chat completion.
///
/// The default [`ChatDriver::chat_completion`] drains these and folds them into
/// a single [`LlmResponse`].
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// A chunk of assistant output text; chunks are concatenated in arrival order.
    TextDelta(String),
    /// A chunk of reasoning ("thinking") text; chunks are concatenated in order.
    ThinkingDelta(String),
    /// Signature for the thinking content. When folded by the default
    /// `chat_completion`, the last one received wins.
    ThinkingSignature(String),
    /// A reasoning item reported by the provider.
    ReasonItem {
        /// Provider that produced the item (presumably a provider identifier — TODO confirm).
        provider: String,
        /// Model that produced the item, if reported.
        model: Option<String>,
        /// Identifier of the reasoning item.
        item_id: String,
        /// Opaque encrypted reasoning payload, if any. The default
        /// `chat_completion` stores it as the response's thinking signature.
        encrypted_content: Option<String>,
        /// Summary text fragments for the reasoning item.
        summary: Vec<String>,
        /// Token count for the reasoning item, if reported.
        token_count: Option<u32>,
    },
    /// Tool calls requested by the model. The default `chat_completion` keeps
    /// only the most recently received batch.
    ToolCalls(Vec<ToolCall>),
    /// Completion metadata (usage, model, finish reason, ...). The default
    /// `chat_completion` keeps the last one received.
    Done(Box<LlmCompletionMetadata>),
    /// Provider-reported error; the default `chat_completion` turns it into
    /// `AgentLoopError::llm` and stops reading the stream.
    Error(String),
}
73
/// A model returned by [`ChatDriver::list_models`].
#[derive(Debug, Clone)]
pub struct DiscoveredModel {
    /// Model identifier as reported by the provider.
    pub model_id: String,
    /// Human-readable name, if the provider supplies one.
    pub display_name: Option<String>,
    /// Creation timestamp (UTC), if the provider supplies one.
    pub created_at: Option<DateTime<Utc>>,
    /// Owner reported by the provider, if any.
    pub owned_by: Option<String>,
    /// Model profile derived during discovery, if any.
    pub discovered_profile: Option<crate::model::ModelProfile>,
}
98
/// Usage and bookkeeping metadata delivered by [`LlmStreamEvent::Done`].
///
/// Every field is optional. `Default` yields all-`None`, which is what the
/// default `chat_completion` returns when the stream never emits `Done`.
#[derive(Debug, Clone, Default)]
pub struct LlmCompletionMetadata {
    /// Total tokens reported for the call.
    pub total_tokens: Option<u32>,
    /// Prompt (input) tokens reported for the call.
    pub prompt_tokens: Option<u32>,
    /// Completion (output) tokens reported for the call.
    pub completion_tokens: Option<u32>,
    /// Prompt-cache read tokens, if the provider reports them.
    pub cache_read_tokens: Option<u32>,
    /// Prompt-cache write/creation tokens, if the provider reports them.
    pub cache_creation_tokens: Option<u32>,
    /// Cost reported by the provider, in USD.
    pub provider_cost_usd: Option<f64>,
    /// Model name reported in the response.
    pub model: Option<String>,
    /// Provider's finish/stop reason string.
    pub finish_reason: Option<String>,
    /// Retry bookkeeping from `crate::llm_retry`, if the call was retried.
    pub retry_metadata: Option<crate::llm_retry::RetryMetadata>,
    /// Provider response identifier (presumably what callers feed back as
    /// `LlmCallConfig::previous_response_id` — TODO confirm).
    pub response_id: Option<String>,
    /// Phase label reported with the completion, if any.
    pub phase: Option<String>,
}
136
137#[async_trait]
159pub trait ChatDriver: Send + Sync {
160 async fn chat_completion_stream(
162 &self,
163 messages: Vec<LlmMessage>,
164 config: &LlmCallConfig,
165 ) -> Result<LlmResponseStream>;
166
167 async fn chat_completion(
169 &self,
170 messages: Vec<LlmMessage>,
171 config: &LlmCallConfig,
172 ) -> Result<LlmResponse> {
173 use futures::StreamExt;
174
175 let mut stream = self.chat_completion_stream(messages, config).await?;
176 let mut text = String::new();
177 let mut thinking = String::new();
178 let mut thinking_signature: Option<String> = None;
179 let mut tool_calls = Vec::new();
180 let mut metadata = LlmCompletionMetadata::default();
181
182 while let Some(event) = stream.next().await {
183 match event? {
184 LlmStreamEvent::TextDelta(delta) => text.push_str(&delta),
185 LlmStreamEvent::ThinkingDelta(delta) => thinking.push_str(&delta),
186 LlmStreamEvent::ThinkingSignature(sig) => thinking_signature = Some(sig),
187 LlmStreamEvent::ReasonItem {
188 encrypted_content, ..
189 } => {
190 if let Some(sig) = encrypted_content {
191 thinking_signature = Some(sig);
192 }
193 }
194 LlmStreamEvent::ToolCalls(calls) => tool_calls = calls,
195 LlmStreamEvent::Done(meta) => metadata = *meta,
196 LlmStreamEvent::Error(err) => return Err(crate::error::AgentLoopError::llm(err)),
197 }
198 }
199
200 Ok(LlmResponse {
201 text,
202 thinking: if thinking.is_empty() {
203 None
204 } else {
205 Some(thinking)
206 },
207 thinking_signature,
208 tool_calls: if tool_calls.is_empty() {
209 None
210 } else {
211 Some(tool_calls)
212 },
213 metadata,
214 })
215 }
216
217 async fn list_models(&self) -> Result<Option<Vec<DiscoveredModel>>> {
225 Ok(None)
227 }
228
229 fn supports_compact(&self) -> bool {
238 false
240 }
241
242 fn supports_parallel_tool_calls(&self, _model: &str) -> bool {
253 false
254 }
255
256 async fn compact(&self, _request: CompactRequest) -> Result<Option<CompactResponse>> {
276 Ok(None)
278 }
279}
280
/// Lets a boxed trait object be used wherever a `ChatDriver` is expected.
///
/// Every method, including the ones with trait defaults, is forwarded explicitly.
/// This way a driver's own overrides (for example a custom `chat_completion`) are
/// honoured instead of being replaced by the defaults.
#[async_trait]
impl ChatDriver for Box<dyn ChatDriver> {
    // `(**self)` derefs `&Box<dyn ChatDriver>` down to `dyn ChatDriver`, so each call
    // dispatches through the inner driver's vtable instead of recursing into this impl.
    async fn chat_completion_stream(
        &self,
        messages: Vec<LlmMessage>,
        config: &LlmCallConfig,
    ) -> Result<LlmResponseStream> {
        (**self).chat_completion_stream(messages, config).await
    }

    async fn chat_completion(
        &self,
        messages: Vec<LlmMessage>,
        config: &LlmCallConfig,
    ) -> Result<LlmResponse> {
        (**self).chat_completion(messages, config).await
    }

    async fn list_models(&self) -> Result<Option<Vec<DiscoveredModel>>> {
        (**self).list_models().await
    }

    fn supports_compact(&self) -> bool {
        (**self).supports_compact()
    }

    fn supports_parallel_tool_calls(&self, model: &str) -> bool {
        (**self).supports_parallel_tool_calls(model)
    }

    async fn compact(&self, request: CompactRequest) -> Result<Option<CompactResponse>> {
        (**self).compact(request).await
    }
}
316
/// A single message in the provider-facing conversation format.
#[derive(Debug, Clone)]
pub struct LlmMessage {
    /// Who authored the message.
    pub role: LlmMessageRole,
    /// Message body: plain text or multimodal parts.
    pub content: LlmMessageContent,
    /// Tool calls carried by the message (set from the stored message's tool calls).
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Id of the tool call this message answers, if any.
    pub tool_call_id: Option<String>,
    /// Execution phase copied from the stored message, if any.
    pub phase: Option<crate::message::ExecutionPhase>,
    /// Reasoning text attached to the message, if any.
    pub thinking: Option<String>,
    /// Signature accompanying `thinking`, if any.
    pub thinking_signature: Option<String>,
}
340
341impl LlmMessage {
342 pub fn text(role: LlmMessageRole, content: impl Into<String>) -> Self {
344 Self {
345 role,
346 content: LlmMessageContent::Text(content.into()),
347 tool_calls: None,
348 tool_call_id: None,
349 phase: None,
350 thinking: None,
351 thinking_signature: None,
352 }
353 }
354
355 pub fn parts(role: LlmMessageRole, parts: Vec<LlmContentPart>) -> Self {
357 Self {
358 role,
359 content: LlmMessageContent::Parts(parts),
360 tool_calls: None,
361 tool_call_id: None,
362 phase: None,
363 thinking: None,
364 thinking_signature: None,
365 }
366 }
367
368 pub fn content_as_text(&self) -> String {
370 self.content.to_text()
371 }
372
373 pub fn prepend_text_prefix(&mut self, prefix: &str) {
378 match &mut self.content {
379 LlmMessageContent::Text(text) => {
380 *text = format!("{}{}", prefix, text);
381 }
382 LlmMessageContent::Parts(parts) => {
383 for part in parts.iter_mut() {
384 if let LlmContentPart::Text { text } = part {
385 *text = format!("{}{}", prefix, text);
386 return;
387 }
388 }
389 parts.insert(
391 0,
392 LlmContentPart::Text {
393 text: prefix.to_string(),
394 },
395 );
396 }
397 }
398 }
399}
400
401pub fn fold_system_messages(messages: &[LlmMessage]) -> Option<String> {
412 let mut system: Option<String> = None;
413 for msg in messages {
414 if msg.role == LlmMessageRole::System {
415 let text = msg.content.to_text();
416 system = Some(match system.take() {
417 Some(existing) if !existing.is_empty() => format!("{existing}\n\n{text}"),
418 _ => text,
419 });
420 }
421 }
422 system
423}
424
/// Body of an [`LlmMessage`].
#[derive(Debug, Clone)]
pub enum LlmMessageContent {
    /// A single plain-text string.
    Text(String),
    /// An ordered list of multimodal parts (text, image, audio).
    Parts(Vec<LlmContentPart>),
}
433
434impl LlmMessageContent {
435 pub fn to_text(&self) -> String {
437 match self {
438 LlmMessageContent::Text(s) => s.clone(),
439 LlmMessageContent::Parts(parts) => parts
440 .iter()
441 .filter_map(|p| match p {
442 LlmContentPart::Text { text } => Some(text.clone()),
443 _ => None,
444 })
445 .collect::<Vec<_>>()
446 .join(""),
447 }
448 }
449
450 pub fn is_text(&self) -> bool {
452 matches!(self, LlmMessageContent::Text(_))
453 }
454
455 pub fn is_parts(&self) -> bool {
457 matches!(self, LlmMessageContent::Parts(_))
458 }
459}
460
461impl From<String> for LlmMessageContent {
462 fn from(s: String) -> Self {
463 LlmMessageContent::Text(s)
464 }
465}
466
467impl From<&str> for LlmMessageContent {
468 fn from(s: &str) -> Self {
469 LlmMessageContent::Text(s.to_string())
470 }
471}
472
/// One element of multi-part message content.
#[derive(Debug, Clone)]
pub enum LlmContentPart {
    /// Plain text.
    Text { text: String },
    /// An image referenced by URL. `from_message_with_images` may also put a
    /// `data:<media_type>;base64,...` URL here.
    Image { url: String },
    /// Audio referenced by URL.
    Audio { url: String },
}
483
484impl LlmContentPart {
485 pub fn text(text: impl Into<String>) -> Self {
487 LlmContentPart::Text { text: text.into() }
488 }
489
490 pub fn image(url: impl Into<String>) -> Self {
492 LlmContentPart::Image { url: url.into() }
493 }
494
495 pub fn audio(url: impl Into<String>) -> Self {
497 LlmContentPart::Audio { url: url.into() }
498 }
499}
500
/// Author role of an [`LlmMessage`].
///
/// When converting stored messages, `MessageRole::Agent` maps to `Assistant`
/// and `MessageRole::ToolResult` maps to `Tool`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LlmMessageRole {
    System,
    User,
    Assistant,
    Tool,
}
509
/// Tool-search settings passed through [`LlmCallConfig::tool_search`].
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct ToolSearchConfig {
    /// Whether tool search is requested.
    pub enabled: bool,
    /// NOTE(review): presumably the tool-count threshold at which tool search
    /// applies — TODO confirm against the drivers that read it.
    pub threshold: usize,
}
527
/// How prompt caching is applied. Serialized in snake_case (`"auto"`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum PromptCacheStrategy {
    /// The only strategy currently defined, and the default.
    #[default]
    Auto,
}
537
/// Prompt-caching settings passed through [`LlmCallConfig::prompt_cache`].
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct PromptCacheConfig {
    /// Whether prompt caching is requested.
    pub enabled: bool,
    /// Caching strategy; defaults to `Auto` when absent from input.
    #[serde(default)]
    pub strategy: PromptCacheStrategy,
    /// NOTE(review): presumably the name of a pre-created Gemini cached-content
    /// resource to reference — TODO confirm. Omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub gemini_cached_content: Option<String>,
}
559
/// Named shortcut for common OpenRouter provider-routing settings.
///
/// `OpenRouterRoutingConfig::apply_presets` expands presets into provider
/// fields. Serialized as an internally tagged object with a snake_case `kind`
/// field, e.g. `{"kind": "zdr_only"}`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum OpenRouterRoutingPreset {
    /// Sets `require_parameters = true` and sorts providers by price.
    CheapestWithTools,
    /// Sorts providers by throughput.
    /// NOTE(review): despite the name, this does not sort by `Latency` — confirm intent.
    LowestLatencyReview,
    /// Sets `zdr = true`.
    ZdrOnly,
    /// Sets `allow_fallbacks = true` (only if not already set by a preset).
    ByokFirst,
    /// Sets `data_collection = deny`.
    NoDataCollection,
    /// Sets `require_parameters = true`.
    StrictJson,
    /// Sets `require_parameters = true`.
    ReasoningRequired,
    /// Caps price. Values are USD per million tokens; `apply_presets` divides
    /// them by 1,000,000 before storing them in `OpenRouterMaxPrice`.
    MaxPrice {
        /// Prompt-token price cap, USD per million tokens.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        prompt_usd_per_million: Option<f64>,
        /// Completion-token price cap, USD per million tokens.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        completion_usd_per_million: Option<f64>,
    },
}
598
/// How requests use shared OpenRouter capacity vs. bring-your-own-key (BYOK)
/// providers. Applied by `OpenRouterRoutingConfig::apply_capacity_strategy`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum OpenRouterCapacityStrategy {
    /// Default; leaves routing unchanged.
    #[default]
    SharedCapacity,
    /// Sets `provider.allow_fallbacks = true` unless it is already set.
    ByokFirst,
    /// Requires a non-empty `provider.only` and forces
    /// `provider.allow_fallbacks = false`.
    ByokOnly,
}
623
/// OpenRouter server-side tools that can be enabled on a request.
///
/// Serialized in snake_case, which matches `name()`. The wire type sent to
/// OpenRouter is `openrouter:<name>` (see `wire_type()`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum OpenRouterServerToolKind {
    WebSearch,
    WebFetch,
    Datetime,
    ImageGeneration,
    ApplyPatch,
    Fusion,
    Advisor,
    Subagent,
}
644
645impl OpenRouterServerToolKind {
646 pub const ALL: [OpenRouterServerToolKind; 8] = [
648 Self::WebSearch,
649 Self::WebFetch,
650 Self::Datetime,
651 Self::ImageGeneration,
652 Self::ApplyPatch,
653 Self::Fusion,
654 Self::Advisor,
655 Self::Subagent,
656 ];
657
658 pub fn name(&self) -> &'static str {
660 match self {
661 Self::WebSearch => "web_search",
662 Self::WebFetch => "web_fetch",
663 Self::Datetime => "datetime",
664 Self::ImageGeneration => "image_generation",
665 Self::ApplyPatch => "apply_patch",
666 Self::Fusion => "fusion",
667 Self::Advisor => "advisor",
668 Self::Subagent => "subagent",
669 }
670 }
671
672 pub fn display_name(&self) -> &'static str {
674 match self {
675 Self::WebSearch => "Web Search",
676 Self::WebFetch => "Web Fetch",
677 Self::Datetime => "Date & Time",
678 Self::ImageGeneration => "Image Generation",
679 Self::ApplyPatch => "Apply Patch",
680 Self::Fusion => "Fusion",
681 Self::Advisor => "Advisor",
682 Self::Subagent => "Subagent",
683 }
684 }
685
686 pub fn wire_type(&self) -> String {
689 format!("openrouter:{}", self.name())
690 }
691
692 pub fn from_name(name: &str) -> Option<Self> {
694 Self::ALL.into_iter().find(|kind| kind.name() == name)
695 }
696}
697
/// A server tool to enable on an OpenRouter request, with optional
/// tool-specific JSON parameters.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterServerTool {
    /// Which server tool to enable.
    pub kind: OpenRouterServerToolKind,
    /// Free-form JSON parameters for the tool; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    #[cfg_attr(feature = "openapi", schema(value_type = Option<Object>))]
    pub parameters: Option<serde_json::Value>,
}
709
710impl OpenRouterServerTool {
711 pub fn new(kind: OpenRouterServerToolKind) -> Self {
713 Self {
714 kind,
715 parameters: None,
716 }
717 }
718
719 pub fn with_parameters(kind: OpenRouterServerToolKind, parameters: serde_json::Value) -> Self {
721 Self {
722 kind,
723 parameters: Some(parameters),
724 }
725 }
726}
727
/// OpenRouter-specific routing options attached to a call.
///
/// Every field is optional and is omitted from serialized output when
/// empty/`None`.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterRoutingConfig {
    /// Model list. When non-empty, `validate_for_primary_model` requires
    /// `models[0]` to equal the primary model.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub models: Vec<String>,
    /// Routing mode; `Fallback` requires a non-empty `models`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub route: Option<OpenRouterRoute>,
    /// Explicit provider-routing preferences.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider: Option<OpenRouterProviderRouting>,
    /// Plugin configuration (web search, file handling).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub plugins: Option<OpenRouterPluginConfig>,
    /// Shared vs. BYOK capacity strategy; see `apply_capacity_strategy`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub capacity_strategy: Option<OpenRouterCapacityStrategy>,
    /// Presets expanded into `provider` by `apply_presets`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub presets: Vec<OpenRouterRoutingPreset>,
    /// Server-side tools to enable.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub server_tools: Vec<OpenRouterServerTool>,
}
761
762impl OpenRouterRoutingConfig {
763 pub fn is_empty(&self) -> bool {
764 self.models.is_empty()
765 && self.route.is_none()
766 && self.provider.is_none()
767 && self.plugins.as_ref().is_none_or(|p| p.is_empty())
768 && matches!(
769 self.capacity_strategy,
770 None | Some(OpenRouterCapacityStrategy::SharedCapacity)
771 )
772 && self.presets.is_empty()
773 && self.server_tools.is_empty()
774 }
775
776 pub fn fallback_models(models: impl IntoIterator<Item = impl Into<String>>) -> Self {
778 let models = models.into_iter().map(Into::into).collect::<Vec<_>>();
779 let route = (!models.is_empty()).then_some(OpenRouterRoute::Fallback);
780 Self {
781 models,
782 route,
783 provider: None,
784 plugins: None,
785 capacity_strategy: None,
786 presets: vec![],
787 server_tools: vec![],
788 }
789 }
790
791 pub fn validate_for_primary_model(
792 &self,
793 primary_model: &str,
794 ) -> std::result::Result<(), String> {
795 if self.route == Some(OpenRouterRoute::Fallback) && self.models.is_empty() {
796 return Err(
797 "OpenRouter fallback routing requires at least one model in `models`".to_string(),
798 );
799 }
800
801 if let Some(first_model) = self.models.first()
802 && first_model != primary_model
803 {
804 return Err(format!(
805 "OpenRouter routing models[0] ('{first_model}') must match primary model ('{primary_model}')"
806 ));
807 }
808
809 Ok(())
810 }
811
812 pub fn apply_capacity_strategy(&self) -> std::result::Result<Self, String> {
822 match self.capacity_strategy {
823 None | Some(OpenRouterCapacityStrategy::SharedCapacity) => Ok(self.clone()),
824 Some(OpenRouterCapacityStrategy::ByokFirst) => {
825 let mut result = self.clone();
826 let provider = result.provider.get_or_insert_with(Default::default);
827 if provider.allow_fallbacks.is_none() {
828 provider.allow_fallbacks = Some(true);
829 }
830 Ok(result)
831 }
832 Some(OpenRouterCapacityStrategy::ByokOnly) => {
833 let only_is_empty = self.provider.as_ref().is_none_or(|p| p.only.is_empty());
834 if only_is_empty {
835 return Err(
836 "OpenRouter BYOK-only strategy requires provider.only to list at least \
837 one upstream provider slug. Configure the provider list to match the \
838 BYOK providers registered in your OpenRouter workspace."
839 .to_string(),
840 );
841 }
842 let mut result = self.clone();
843 let provider = result.provider.get_or_insert_with(Default::default);
844 provider.allow_fallbacks = Some(false);
845 Ok(result)
846 }
847 }
848 }
849
850 pub fn apply_presets(&self) -> std::result::Result<Self, String> {
860 if self.presets.is_empty() {
861 return Ok(self.clone());
862 }
863
864 let mut derived = OpenRouterProviderRouting::default();
865
866 for preset in &self.presets {
867 match preset {
868 OpenRouterRoutingPreset::CheapestWithTools => {
869 derived.require_parameters = Some(true);
870 derived.sort = Some(OpenRouterProviderSort::Simple(
871 OpenRouterProviderSortBy::Price,
872 ));
873 }
874 OpenRouterRoutingPreset::LowestLatencyReview => {
875 derived.sort = Some(OpenRouterProviderSort::Simple(
876 OpenRouterProviderSortBy::Throughput,
877 ));
878 }
879 OpenRouterRoutingPreset::ZdrOnly => {
880 derived.zdr = Some(true);
881 }
882 OpenRouterRoutingPreset::ByokFirst => {
883 if derived.allow_fallbacks.is_none() {
884 derived.allow_fallbacks = Some(true);
885 }
886 }
887 OpenRouterRoutingPreset::NoDataCollection => {
888 derived.data_collection = Some(OpenRouterDataCollection::Deny);
889 }
890 OpenRouterRoutingPreset::StrictJson
891 | OpenRouterRoutingPreset::ReasoningRequired => {
892 derived.require_parameters = Some(true);
893 }
894 OpenRouterRoutingPreset::MaxPrice {
895 prompt_usd_per_million,
896 completion_usd_per_million,
897 } => {
898 if prompt_usd_per_million.is_some_and(|v| v < 0.0)
899 || completion_usd_per_million.is_some_and(|v| v < 0.0)
900 {
901 return Err(
902 "MaxPrice preset values must be non-negative USD per million tokens"
903 .to_string(),
904 );
905 }
906 if prompt_usd_per_million.is_some() || completion_usd_per_million.is_some() {
907 let mp = derived.max_price.get_or_insert_with(Default::default);
908 if let Some(p) = prompt_usd_per_million {
909 mp.prompt = Some(p / 1_000_000.0);
910 }
911 if let Some(c) = completion_usd_per_million {
912 mp.completion = Some(c / 1_000_000.0);
913 }
914 }
915 }
916 }
917 }
918
919 let merged = merge_provider_routing(derived, self.provider.clone().unwrap_or_default());
921
922 let mut result = self.clone();
923 result.presets = vec![];
924 result.provider = if merged.is_empty() {
925 None
926 } else {
927 Some(merged)
928 };
929 Ok(result)
930 }
931}
932
933fn merge_provider_routing(
937 derived: OpenRouterProviderRouting,
938 explicit: OpenRouterProviderRouting,
939) -> OpenRouterProviderRouting {
940 OpenRouterProviderRouting {
941 order: if !explicit.order.is_empty() {
942 explicit.order
943 } else {
944 derived.order
945 },
946 only: if !explicit.only.is_empty() {
947 explicit.only
948 } else {
949 derived.only
950 },
951 ignore: if !explicit.ignore.is_empty() {
952 explicit.ignore
953 } else {
954 derived.ignore
955 },
956 allow_fallbacks: explicit.allow_fallbacks.or(derived.allow_fallbacks),
957 require_parameters: explicit.require_parameters.or(derived.require_parameters),
958 data_collection: explicit.data_collection.or(derived.data_collection),
959 zdr: explicit.zdr.or(derived.zdr),
960 enforce_distillable_text: explicit
961 .enforce_distillable_text
962 .or(derived.enforce_distillable_text),
963 quantizations: if !explicit.quantizations.is_empty() {
964 explicit.quantizations
965 } else {
966 derived.quantizations
967 },
968 sort: explicit.sort.or(derived.sort),
969 max_price: explicit.max_price.or(derived.max_price),
970 }
971}
972
/// OpenRouter routing mode. Serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum OpenRouterRoute {
    /// Use `OpenRouterRoutingConfig::models` as a fallback list; requires at
    /// least one model (enforced by `validate_for_primary_model`).
    Fallback,
}
980
/// Provider-selection preferences for OpenRouter.
///
/// Presumably mirrors OpenRouter's `provider` request object; the exact
/// semantics of each field are OpenRouter's. Empty lists and `None` values are
/// omitted from serialized output.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterProviderRouting {
    /// Preferred provider order.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub order: Vec<String>,
    /// Restrict routing to these providers (required by the `ByokOnly` strategy).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub only: Vec<String>,
    /// Providers to exclude.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub ignore: Vec<String>,
    /// Whether routing may fall back to other providers.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub allow_fallbacks: Option<bool>,
    /// Whether providers must support every request parameter.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub require_parameters: Option<bool>,
    /// Data-collection policy requirement.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub data_collection: Option<OpenRouterDataCollection>,
    /// Zero-data-retention requirement.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub zdr: Option<bool>,
    /// Distillable-text enforcement flag.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub enforce_distillable_text: Option<bool>,
    /// Accepted quantization levels.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub quantizations: Vec<String>,
    /// Provider sort preference.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sort: Option<OpenRouterProviderSort>,
    /// Price caps (per-token values; see `OpenRouterRoutingPreset::MaxPrice`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_price: Option<OpenRouterMaxPrice>,
}
1019
1020impl OpenRouterProviderRouting {
1021 pub fn is_empty(&self) -> bool {
1022 self.order.is_empty()
1023 && self.only.is_empty()
1024 && self.ignore.is_empty()
1025 && self.allow_fallbacks.is_none()
1026 && self.require_parameters.is_none()
1027 && self.data_collection.is_none()
1028 && self.zdr.is_none()
1029 && self.enforce_distillable_text.is_none()
1030 && self.quantizations.is_empty()
1031 && self.sort.is_none()
1032 && self.max_price.is_none()
1033 }
1034}
1035
/// Data-collection policy requirement for upstream providers. Serialized in
/// snake_case (`"allow"` / `"deny"`); the `NoDataCollection` preset sets `Deny`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum OpenRouterDataCollection {
    Allow,
    Deny,
}
1044
/// Provider sort preference, serialized untagged.
///
/// It is either a bare sort key (e.g. `"price"`) or an options object
/// (`{"by": ..., "partition": ...}`).
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(untagged)]
pub enum OpenRouterProviderSort {
    /// Bare sort key.
    Simple(OpenRouterProviderSortBy),
    /// Sort key with additional options.
    Advanced(OpenRouterProviderSortOptions),
}
1053
/// Key used to sort candidate providers. Serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum OpenRouterProviderSortBy {
    Price,
    Throughput,
    Latency,
}
1063
/// Object form of a provider sort preference.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterProviderSortOptions {
    /// Sort key.
    pub by: OpenRouterProviderSortBy,
    /// Optional partitioning; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub partition: Option<OpenRouterSortPartition>,
}
1072
/// Partitioning mode for an advanced provider sort. Serialized in snake_case
/// (`"model"` / `"none"`).
///
/// The `None` variant is an enum value, distinct from `Option::None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum OpenRouterSortPartition {
    Model,
    None,
}
1081
/// Price caps for provider selection.
///
/// `apply_presets` fills `prompt` and `completion` with USD-per-million values
/// divided by 1,000,000, i.e. USD per token. Unset fields are omitted from
/// serialized output.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterMaxPrice {
    /// Prompt-token price cap.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt: Option<f64>,
    /// Completion-token price cap.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub completion: Option<f64>,
    /// Per-request price cap.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub request: Option<f64>,
    /// Per-image price cap.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub image: Option<f64>,
}
1096
/// Settings for OpenRouter's web-search plugin; unset fields are omitted.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterWebSearchPlugin {
    /// Maximum number of search results to use.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_results: Option<u32>,
    /// Custom prompt used for the search step.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub search_prompt: Option<String>,
}
1112
/// Marker for OpenRouter's file plugin; it has no settings.
///
/// Only its presence (`Some`) in `OpenRouterPluginConfig::file` matters.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterFilePlugin {}
1120
/// OpenRouter plugins to enable; absent plugins are omitted from output.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct OpenRouterPluginConfig {
    /// Web-search plugin settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub web: Option<OpenRouterWebSearchPlugin>,
    /// File plugin marker.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub file: Option<OpenRouterFilePlugin>,
}
1135
1136impl OpenRouterPluginConfig {
1137 pub fn is_empty(&self) -> bool {
1138 self.web.is_none() && self.file.is_none()
1139 }
1140}
1141
/// Metadata key (presumably looked up in `LlmCallConfig::metadata`) for the
/// value of OpenRouter's `HTTP-Referer` attribution header — TODO confirm in the driver.
pub const OPENROUTER_HTTP_REFERER_METADATA_KEY: &str = "openrouter.http_referer";
/// Metadata key (presumably looked up in `LlmCallConfig::metadata`) for the
/// value of OpenRouter's `X-Title` attribution header — TODO confirm in the driver.
pub const OPENROUTER_X_TITLE_METADATA_KEY: &str = "openrouter.x_title";
1146
/// Per-call settings handed to a [`ChatDriver`].
///
/// Usually built from a `RuntimeAgent` via `From` or [`LlmCallConfigBuilder`].
#[derive(Debug, Clone)]
pub struct LlmCallConfig {
    /// Model identifier to call.
    pub model: String,
    /// Sampling temperature, if set.
    pub temperature: Option<f32>,
    /// Output token cap, if set.
    pub max_tokens: Option<u32>,
    /// Tools offered to the model.
    pub tools: Vec<ToolDefinition>,
    /// Provider-specific reasoning-effort setting, if any.
    pub reasoning_effort: Option<String>,
    /// Free-form string metadata for the call.
    pub metadata: HashMap<String, String>,
    /// Id of a previous response to continue from, if any (presumably for
    /// providers with server-side conversation state — TODO confirm).
    pub previous_response_id: Option<String>,
    /// Tool-search settings, if any.
    pub tool_search: Option<ToolSearchConfig>,
    /// Prompt-caching settings, if any.
    pub prompt_cache: Option<PromptCacheConfig>,
    /// OpenRouter routing options, if any.
    pub openrouter_routing: Option<OpenRouterRoutingConfig>,
    /// Requested parallel-tool-calls setting; see `resolved_parallel_tool_calls`.
    pub parallel_tool_calls: Option<bool>,
}
1177
1178impl LlmCallConfig {
1179 pub fn resolved_parallel_tool_calls(&self, supported: bool) -> Option<bool> {
1189 if supported {
1190 self.parallel_tool_calls
1191 } else {
1192 None
1193 }
1194 }
1195}
1196
1197impl From<&RuntimeAgent> for LlmCallConfig {
1198 fn from(runtime_agent: &RuntimeAgent) -> Self {
1199 Self {
1200 model: runtime_agent.model.clone(),
1201 temperature: runtime_agent.temperature,
1202 max_tokens: runtime_agent.max_tokens,
1203 tools: runtime_agent.tools.clone(),
1204 reasoning_effort: None, metadata: HashMap::new(), previous_response_id: None,
1207 tool_search: runtime_agent.tool_search.clone(),
1208 prompt_cache: runtime_agent.prompt_cache.clone(),
1209 openrouter_routing: runtime_agent.openrouter_routing.clone(),
1210 parallel_tool_calls: runtime_agent.parallel_tool_calls,
1211 }
1212 }
1213}
1214
/// Aggregated result of a non-streaming completion
/// (see [`ChatDriver::chat_completion`]).
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Concatenated assistant text (may be empty).
    pub text: String,
    /// Concatenated reasoning text; `None` when no thinking was emitted.
    pub thinking: Option<String>,
    /// Last thinking signature / encrypted reasoning payload received, if any.
    pub thinking_signature: Option<String>,
    /// Tool calls requested by the model; `None` when there are none.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Completion metadata (all-`None` default when none was reported).
    pub metadata: LlmCompletionMetadata,
}
1226
/// Fluent builder for [`LlmCallConfig`], seeded from a `RuntimeAgent`.
pub struct LlmCallConfigBuilder {
    /// Configuration under construction.
    config: LlmCallConfig,
}
1248
1249impl LlmCallConfigBuilder {
1250 pub fn from(runtime_agent: &RuntimeAgent) -> Self {
1252 Self {
1253 config: LlmCallConfig::from(runtime_agent),
1254 }
1255 }
1256
1257 pub fn reasoning_effort(mut self, effort: impl Into<String>) -> Self {
1259 self.config.reasoning_effort = Some(effort.into());
1260 self
1261 }
1262
1263 pub fn model(mut self, model: impl Into<String>) -> Self {
1265 self.config.model = model.into();
1266 self
1267 }
1268
1269 pub fn temperature(mut self, temp: f32) -> Self {
1271 self.config.temperature = Some(temp);
1272 self
1273 }
1274
1275 pub fn max_tokens(mut self, tokens: u32) -> Self {
1277 self.config.max_tokens = Some(tokens);
1278 self
1279 }
1280
1281 pub fn tools(mut self, tools: Vec<ToolDefinition>) -> Self {
1283 self.config.tools = tools;
1284 self
1285 }
1286
1287 pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
1292 self.config.metadata = metadata;
1293 self
1294 }
1295
1296 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1298 self.config.metadata.insert(key.into(), value.into());
1299 self
1300 }
1301
1302 pub fn previous_response_id(mut self, id: Option<String>) -> Self {
1304 self.config.previous_response_id = id;
1305 self
1306 }
1307
1308 pub fn tool_search(mut self, config: ToolSearchConfig) -> Self {
1310 self.config.tool_search = Some(config);
1311 self
1312 }
1313
1314 pub fn prompt_cache(mut self, config: PromptCacheConfig) -> Self {
1316 self.config.prompt_cache = Some(config);
1317 self
1318 }
1319
1320 pub fn openrouter_routing(mut self, config: OpenRouterRoutingConfig) -> Self {
1322 self.config.openrouter_routing = (!config.is_empty()).then_some(config);
1323 self
1324 }
1325
1326 pub fn parallel_tool_calls(mut self, parallel_tool_calls: Option<bool>) -> Self {
1328 self.config.parallel_tool_calls = parallel_tool_calls;
1329 self
1330 }
1331
1332 pub fn build(self) -> LlmCallConfig {
1334 self.config
1335 }
1336}
1337
1338impl From<&crate::message::Message> for LlmMessage {
1343 fn from(msg: &crate::message::Message) -> Self {
1349 let role = match msg.role {
1350 crate::message::MessageRole::System => LlmMessageRole::System,
1351 crate::message::MessageRole::User => LlmMessageRole::User,
1352 crate::message::MessageRole::Agent => LlmMessageRole::Assistant,
1353 crate::message::MessageRole::ToolResult => LlmMessageRole::Tool,
1354 };
1355
1356 let tool_calls: Vec<ToolCall> = msg
1358 .tool_calls()
1359 .into_iter()
1360 .map(|tc| ToolCall {
1361 id: tc.id.clone(),
1362 name: tc.name.clone(),
1363 arguments: tc.arguments.clone(),
1364 })
1365 .collect();
1366
1367 LlmMessage {
1368 role,
1369 content: LlmMessageContent::Text(msg.content_to_llm_string()),
1370 tool_calls: if tool_calls.is_empty() {
1371 None
1372 } else {
1373 Some(tool_calls)
1374 },
1375 tool_call_id: msg.tool_call_id().map(|s| s.to_string()),
1376 phase: msg.phase,
1377 thinking: msg.thinking.clone(),
1378 thinking_signature: msg.thinking_signature.clone(),
1379 }
1380 }
1381}
1382
1383use crate::traits::ResolvedImage;
1388use uuid::Uuid;
1389
1390impl LlmMessage {
1391 pub fn from_message_with_images(
1411 msg: &crate::message::Message,
1412 resolved_images: &HashMap<Uuid, ResolvedImage>,
1413 ) -> Self {
1414 use crate::message::{ContentPart, MessageRole};
1415
1416 let role = match msg.role {
1417 MessageRole::System => LlmMessageRole::System,
1418 MessageRole::User => LlmMessageRole::User,
1419 MessageRole::Agent => LlmMessageRole::Assistant,
1420 MessageRole::ToolResult => LlmMessageRole::Tool,
1421 };
1422
1423 let mut parts: Vec<LlmContentPart> = Vec::new();
1425 let mut tool_calls: Vec<ToolCall> = Vec::new();
1426
1427 for part in &msg.content {
1428 match part {
1429 ContentPart::Text(t) => {
1430 parts.push(LlmContentPart::Text {
1431 text: t.text.clone(),
1432 });
1433 }
1434 ContentPart::Image(img) => {
1435 if let Some(url) = &img.url {
1437 parts.push(LlmContentPart::Image { url: url.clone() });
1438 } else if let (Some(base64), Some(media_type)) = (&img.base64, &img.media_type)
1439 {
1440 let data_url = format!("data:{};base64,{}", media_type, base64);
1441 parts.push(LlmContentPart::Image { url: data_url });
1442 }
1443 }
1444 ContentPart::ImageFile(img_file) => {
1445 if let Some(resolved) = resolved_images.get(&img_file.image_id.uuid()) {
1447 parts.push(LlmContentPart::Image {
1448 url: resolved.to_data_url(),
1449 });
1450 } else {
1451 parts.push(LlmContentPart::Text {
1453 text: format!("[Image not found: {}]", img_file.image_id),
1454 });
1455 }
1456 }
1457 ContentPart::ToolCall(tc) => {
1458 tool_calls.push(ToolCall {
1460 id: tc.id.clone(),
1461 name: tc.name.clone(),
1462 arguments: tc.arguments.clone(),
1463 });
1464 }
1465 ContentPart::ToolResult(tr) => {
1466 let text = if let Some(err) = &tr.error {
1468 format!("Tool error: {}", err)
1469 } else if let Some(res) = &tr.result {
1470 serde_json::to_string(res).unwrap_or_else(|_| "{}".to_string())
1471 } else {
1472 "{}".to_string()
1473 };
1474 let text = truncate_tool_result(text);
1478 parts.push(LlmContentPart::Text { text });
1479 }
1480 }
1481 }
1482
1483 let content = if parts.len() == 1 && matches!(&parts[0], LlmContentPart::Text { .. }) {
1485 if let LlmContentPart::Text { text } = &parts[0] {
1487 LlmMessageContent::Text(text.clone())
1488 } else {
1489 LlmMessageContent::Parts(parts)
1490 }
1491 } else if parts.is_empty() {
1492 LlmMessageContent::Text(String::new())
1494 } else {
1495 LlmMessageContent::Parts(parts)
1497 };
1498
1499 LlmMessage {
1500 role,
1501 content,
1502 tool_calls: if tool_calls.is_empty() {
1503 None
1504 } else {
1505 Some(tool_calls)
1506 },
1507 tool_call_id: msg.tool_call_id().map(|s| s.to_string()),
1508 phase: msg.phase,
1509 thinking: msg.thinking.clone(),
1510 thinking_signature: msg.thinking_signature.clone(),
1511 }
1512 }
1513
1514 pub fn message_has_image_files(msg: &crate::message::Message) -> bool {
1516 msg.content.iter().any(|p| p.is_image_file())
1517 }
1518
1519 pub fn extract_image_file_ids(msg: &crate::message::Message) -> Vec<Uuid> {
1521 msg.content
1522 .iter()
1523 .filter_map(|p| match p {
1524 crate::message::ContentPart::ImageFile(f) => Some(f.image_id.uuid()),
1525 _ => None,
1526 })
1527 .collect()
1528 }
1529}
1530
1531pub use crate::provider::DriverId;
1536
/// Extra provider-level auth/account data stored on a [`ProviderConfig`] and
/// copied into the [`DriverConfig`] handed to driver factories.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProviderMetadata {
    /// Refresh token, if any.
    /// NOTE(review): presumably used by OAuth-style drivers to renew access — confirm.
    pub refresh_token: Option<String>,
    /// Provider-side account identifier, if any (meaning is driver-specific — confirm).
    pub account_id: Option<String>,
    /// Arbitrary JSON payload; presumably driver-specific extras — confirm with drivers.
    pub extra: Option<serde_json::Value>,
}
1551
/// Provider connection settings that [`DriverRegistry`] uses to build drivers.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// Selects which registered driver descriptor to use.
    pub provider_type: DriverId,
    /// API key or credential document. [`DriverConfig::from_provider_config`] parses it
    /// into named credentials; [`DriverRegistry`] requires it for some providers.
    pub api_key: Option<String>,
    /// Optional base URL.
    /// NOTE(review): presumably overrides the provider's default endpoint — confirm per driver.
    pub base_url: Option<String>,
    /// Extra auth/account metadata forwarded to the driver.
    pub metadata: ProviderMetadata,
}
1564
1565impl ProviderConfig {
1566 pub fn new(provider_type: DriverId) -> Self {
1568 Self {
1569 provider_type,
1570 api_key: None,
1571 base_url: None,
1572 metadata: ProviderMetadata::default(),
1573 }
1574 }
1575
1576 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
1578 self.api_key = Some(api_key.into());
1579 self
1580 }
1581
1582 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
1584 self.base_url = Some(base_url.into());
1585 self
1586 }
1587
1588 pub fn with_metadata(mut self, metadata: ProviderMetadata) -> Self {
1590 self.metadata = metadata;
1591 self
1592 }
1593}
1594
/// Configuration passed to driver factories.
///
/// Built from a [`ProviderConfig`] by [`DriverConfig::from_provider_config`].
#[derive(Debug, Clone)]
pub struct DriverConfig {
    /// Driver id this config was built for.
    pub provider_type: DriverId,
    /// Raw API key / credential document, copied unchanged from the [`ProviderConfig`].
    pub api_key: Option<String>,
    /// Named credentials parsed from `api_key` by
    /// `crate::credential_schema::parse_credential_document`. Prefer
    /// [`DriverConfig::credential`], which treats empty values as absent.
    pub credentials: std::collections::BTreeMap<String, String>,
    /// Optional base URL copied from the provider config.
    pub base_url: Option<String>,
    /// Provider metadata copied from the provider config.
    pub metadata: ProviderMetadata,
}
1621
1622impl DriverConfig {
1623 pub fn from_provider_config(config: &ProviderConfig) -> Self {
1629 Self {
1630 provider_type: config.provider_type.clone(),
1631 credentials: crate::credential_schema::parse_credential_document(
1632 config.api_key.as_deref(),
1633 ),
1634 api_key: config.api_key.clone(),
1635 base_url: config.base_url.clone(),
1636 metadata: config.metadata.clone(),
1637 }
1638 }
1639
1640 pub fn credential(&self, name: &str) -> Option<&str> {
1642 self.credentials
1643 .get(name)
1644 .map(String::as_str)
1645 .filter(|s| !s.is_empty())
1646 }
1647}
1648
1649impl From<&crate::traits::ResolvedModel> for ProviderConfig {
1650 fn from(model: &crate::traits::ResolvedModel) -> Self {
1651 Self {
1652 provider_type: model.provider_type.clone(),
1653 api_key: model.api_key.clone(),
1654 base_url: model.base_url.clone(),
1655 metadata: model.provider_metadata.clone().unwrap_or_default(),
1656 }
1657 }
1658}
1659
/// Owned, type-erased chat driver; the value a [`DriverFactory`] produces.
pub type BoxedChatDriver = Box<dyn ChatDriver>;
1662
/// Input to [`EmbeddingsDriver::embed`].
#[derive(Debug, Clone)]
pub struct EmbedRequest {
    /// Texts to embed.
    pub texts: Vec<String>,
    /// Embedding model identifier.
    pub model: String,
}

/// Output of [`EmbeddingsDriver::embed`].
#[derive(Debug, Clone)]
pub struct EmbedResponse {
    /// Embedding vectors.
    /// NOTE(review): presumably one per input text, in request order — confirm per driver.
    pub embeddings: Vec<Vec<f32>>,
    /// Token usage reported by the provider, when available.
    pub usage_tokens: Option<u32>,
}
1685
/// Errors from [`EmbeddingsDriver`] implementations and from
/// [`DriverRegistry::create_embeddings_driver`].
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingsDriverError {
    /// Provider-side failure.
    ///
    /// The registry also uses this variant for configuration problems such as a missing
    /// driver, a missing factory, or a missing API key.
    #[error("embeddings provider returned an error: {0}")]
    Provider(String),
    /// Request-level failure.
    /// NOTE(review): presumably network/transport errors, going by the name — confirm.
    #[error("embeddings request failed: {0}")]
    Transport(String),
}
1694
/// A driver that turns texts into embedding vectors.
#[async_trait]
pub trait EmbeddingsDriver: Send + Sync {
    /// Embeds the texts in `request`.
    async fn embed(
        &self,
        request: EmbedRequest,
    ) -> std::result::Result<EmbedResponse, EmbeddingsDriverError>;
}

/// Owned, type-erased embeddings driver.
pub type BoxedEmbeddingsDriver = Box<dyn EmbeddingsDriver>;

/// Factory that builds an embeddings driver from a [`DriverConfig`].
pub type EmbeddingsDriverFactory =
    Arc<dyn Fn(&DriverConfig) -> BoxedEmbeddingsDriver + Send + Sync>;

/// Factory that builds a chat driver from a [`DriverConfig`].
pub type DriverFactory = Arc<dyn Fn(&DriverConfig) -> BoxedChatDriver + Send + Sync>;
1725
/// A capability a driver can advertise.
///
/// Serialized in snake_case, e.g. `"chat"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceKind {
    /// Chat completions (served by a chat factory).
    Chat,
    /// Text embeddings (served by an embeddings factory).
    Embeddings,
    /// Realtime sessions (semantics defined by drivers — confirm).
    Realtime,
    /// Image service (semantics defined by drivers — confirm).
    Images,
    /// Reranking service (semantics defined by drivers — confirm).
    Rerank,
}
1745
1746impl std::fmt::Display for ServiceKind {
1747 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1748 let s = match self {
1749 ServiceKind::Chat => "chat",
1750 ServiceKind::Embeddings => "embeddings",
1751 ServiceKind::Realtime => "realtime",
1752 ServiceKind::Images => "images",
1753 ServiceKind::Rerank => "rerank",
1754 };
1755 f.write_str(s)
1756 }
1757}
1758
/// OAuth flow variants a driver can use for sign-in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DriverOAuthFlow {
    /// OpenRouter's PKCE-based flow (per the variant name — confirm details).
    OpenRouterPkce,
}

/// OAuth endpoints plus the flow used to talk to them.
#[derive(Debug, Clone)]
pub struct DriverOAuthConfig {
    /// URL the user is sent to in order to authorize.
    pub authorize_url: String,
    /// URL used to obtain the credential after authorization.
    pub token_url: String,
    /// Flow variant that governs how the endpoints are used.
    pub flow: DriverOAuthFlow,
}

impl DriverOAuthConfig {
    /// OAuth configuration for OpenRouter.
    pub fn openrouter() -> Self {
        let authorize_url = String::from("https://openrouter.ai/auth");
        let token_url = String::from("https://openrouter.ai/api/v1/auth/keys");
        Self {
            authorize_url,
            token_url,
            flow: DriverOAuthFlow::OpenRouterPkce,
        }
    }
}
1807
/// Everything the registry knows about one driver.
///
/// Covers identity, advertised services, credential form, optional OAuth config, and
/// per-service factories.
#[derive(Clone)]
pub struct DriverDescriptor {
    /// Registry key.
    pub id: DriverId,
    /// Human-readable name.
    pub display_name: String,
    /// Services this driver advertises; [`DriverDescriptor::supports`] consults this list.
    /// It is not cross-checked against `chat` / `embeddings`, so a service can be
    /// advertised without a factory.
    pub services: Vec<ServiceKind>,
    /// Credential form describing what must be supplied for this driver.
    pub credential_schema: CredentialFormSchema,
    /// OAuth configuration, if the driver offers an OAuth sign-in flow.
    pub oauth: Option<DriverOAuthConfig>,
    /// Chat driver factory. `None` makes chat driver creation fail.
    pub chat: Option<DriverFactory>,
    /// Embeddings driver factory. `None` makes embeddings driver creation fail.
    pub embeddings: Option<EmbeddingsDriverFactory>,
}
1832
1833impl DriverDescriptor {
1834 pub fn chat_only<F>(id: impl Into<DriverId>, factory: F) -> Self
1839 where
1840 F: Fn(&DriverConfig) -> BoxedChatDriver + Send + Sync + 'static,
1841 {
1842 let id = id.into();
1843 Self {
1844 display_name: default_display_name(&id),
1845 credential_schema: default_credential_schema(&id),
1846 services: vec![ServiceKind::Chat],
1847 oauth: None,
1848 chat: Some(Arc::new(factory)),
1849 embeddings: None,
1850 id,
1851 }
1852 }
1853
1854 pub fn supports(&self, service: ServiceKind) -> bool {
1856 self.services.contains(&service)
1857 }
1858}
1859
1860impl std::fmt::Debug for DriverDescriptor {
1861 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1862 f.debug_struct("DriverDescriptor")
1863 .field("id", &self.id)
1864 .field("display_name", &self.display_name)
1865 .field("services", &self.services)
1866 .field("oauth", &self.oauth.is_some())
1867 .field("chat", &self.chat.is_some())
1868 .field("embeddings", &self.embeddings.is_some())
1869 .finish()
1870 }
1871}
1872
1873fn default_display_name(id: &DriverId) -> String {
1874 match id {
1875 DriverId::OpenAI => "OpenAI".to_string(),
1876 DriverId::OpenRouter => "OpenRouter".to_string(),
1877 DriverId::AzureOpenAI => "Azure OpenAI".to_string(),
1878 DriverId::OpenAICompletions => "OpenAI (Chat Completions)".to_string(),
1879 DriverId::Anthropic => "Anthropic".to_string(),
1880 DriverId::Gemini => "Google Gemini".to_string(),
1881 DriverId::Bedrock => "AWS Bedrock".to_string(),
1882 DriverId::Mai => "Microsoft MAI".to_string(),
1883 DriverId::Fireworks => "Fireworks AI".to_string(),
1884 DriverId::LlmSim => "LLM Simulator".to_string(),
1885 DriverId::External(id) => id.to_string(),
1886 }
1887}
1888
1889fn default_credential_schema(id: &DriverId) -> CredentialFormSchema {
1890 match id {
1891 DriverId::LlmSim | DriverId::External(_) => CredentialFormSchema::empty(),
1893 _ => CredentialFormSchema::api_key(String::new()),
1894 }
1895}
1896
/// Registry of driver descriptors keyed by [`DriverId`].
///
/// Answers which services a provider offers, and builds chat or embeddings drivers from
/// a [`ProviderConfig`].
#[derive(Clone, Default)]
pub struct DriverRegistry {
    descriptors: HashMap<DriverId, DriverDescriptor>,
}
1920
1921impl DriverRegistry {
1922 pub fn new() -> Self {
1924 Self {
1925 descriptors: HashMap::new(),
1926 }
1927 }
1928
1929 pub fn register_descriptor(&mut self, descriptor: DriverDescriptor) {
1935 if self.descriptors.contains_key(&descriptor.id) {
1936 panic!(
1937 "driver already registered for provider '{}'; \
1938 use register_descriptor_or_replace to overwrite intentionally",
1939 descriptor.id
1940 );
1941 }
1942 self.descriptors.insert(descriptor.id.clone(), descriptor);
1943 }
1944
1945 pub fn register_descriptor_or_replace(&mut self, descriptor: DriverDescriptor) {
1947 self.descriptors.insert(descriptor.id.clone(), descriptor);
1948 }
1949
1950 pub fn register<F>(&mut self, provider_type: impl Into<DriverId>, factory: F)
1956 where
1957 F: Fn(&DriverConfig) -> BoxedChatDriver + Send + Sync + 'static,
1958 {
1959 self.register_descriptor(DriverDescriptor::chat_only(provider_type, factory));
1960 }
1961
1962 pub fn register_or_replace<F>(&mut self, provider_type: impl Into<DriverId>, factory: F)
1967 where
1968 F: Fn(&DriverConfig) -> BoxedChatDriver + Send + Sync + 'static,
1969 {
1970 self.register_descriptor_or_replace(DriverDescriptor::chat_only(provider_type, factory));
1971 }
1972
1973 pub fn register_external<F>(&mut self, id: impl Into<Arc<str>>, factory: F)
1978 where
1979 F: Fn(&DriverConfig) -> BoxedChatDriver + Send + Sync + 'static,
1980 {
1981 self.register(DriverId::external(id), factory);
1982 }
1983
1984 pub fn create_chat_driver(&self, config: &ProviderConfig) -> Result<BoxedChatDriver> {
1993 let requires_api_key = !matches!(
1997 config.provider_type,
1998 DriverId::LlmSim | DriverId::External(_) | DriverId::Mai
1999 );
2000 if requires_api_key && config.api_key.is_none() {
2001 return Err(AgentLoopError::llm(
2002 "API key is required. Configure the API key in provider settings.",
2003 ));
2004 }
2005
2006 let descriptor = self.descriptors.get(&config.provider_type).ok_or_else(|| {
2008 AgentLoopError::driver_not_registered(config.provider_type.to_string())
2009 })?;
2010 let factory = descriptor.chat.as_ref().ok_or_else(|| {
2011 AgentLoopError::llm(format!(
2012 "Provider driver '{}' does not implement the chat service.",
2013 config.provider_type
2014 ))
2015 })?;
2016
2017 let driver_config = DriverConfig::from_provider_config(config);
2019 Ok(factory(&driver_config))
2020 }
2021
2022 pub fn has_driver(&self, provider_type: &DriverId) -> bool {
2024 self.descriptors.contains_key(provider_type)
2025 }
2026
2027 pub fn descriptor(&self, provider_type: &DriverId) -> Option<&DriverDescriptor> {
2029 self.descriptors.get(provider_type)
2030 }
2031
2032 pub fn supports(&self, provider_type: &DriverId, service: ServiceKind) -> bool {
2034 self.descriptors
2035 .get(provider_type)
2036 .is_some_and(|d| d.supports(service))
2037 }
2038
2039 pub fn providers_for(&self, service: ServiceKind) -> Vec<DriverId> {
2041 self.descriptors
2042 .values()
2043 .filter(|d| d.supports(service))
2044 .map(|d| d.id.clone())
2045 .collect()
2046 }
2047
2048 pub fn registered_providers(&self) -> Vec<DriverId> {
2050 self.descriptors.keys().cloned().collect()
2051 }
2052
2053 pub fn create_embeddings_driver(
2061 &self,
2062 config: &ProviderConfig,
2063 ) -> std::result::Result<BoxedEmbeddingsDriver, EmbeddingsDriverError> {
2064 let requires_api_key = !matches!(
2065 config.provider_type,
2066 DriverId::LlmSim | DriverId::External(_)
2067 );
2068 if requires_api_key && config.api_key.is_none() {
2069 return Err(EmbeddingsDriverError::Provider(
2070 "API key is required. Configure the API key in provider settings.".to_string(),
2071 ));
2072 }
2073 let descriptor = self.descriptors.get(&config.provider_type).ok_or_else(|| {
2074 EmbeddingsDriverError::Provider(format!(
2075 "No driver registered for provider '{}'",
2076 config.provider_type
2077 ))
2078 })?;
2079 let factory = descriptor.embeddings.as_ref().ok_or_else(|| {
2080 EmbeddingsDriverError::Provider(format!(
2081 "Provider driver '{}' does not implement the embeddings service.",
2082 config.provider_type
2083 ))
2084 })?;
2085 let driver_config = DriverConfig::from_provider_config(config);
2086 Ok(factory(&driver_config))
2087 }
2088}
2089
/// Maximum size, in bytes, of a tool result passed on by [`truncate_tool_result`].
const MAX_TOOL_RESULT_BYTES: usize = 64 * 1024;

/// Notice appended to a tool result cut by [`truncate_tool_result`].
///
/// The "64 KiB" in the text must be kept in sync with [`MAX_TOOL_RESULT_BYTES`].
const TRUNCATION_SUFFIX: &str =
    "\n\n[Output truncated — exceeded 64 KiB limit. Try quiet flags, pipes, or redirect to file.]";

/// Caps a tool result at [`MAX_TOOL_RESULT_BYTES`] bytes.
///
/// Text within the limit is returned untouched. Longer text is handled as follows:
///
/// - it is cut at the last UTF-8 character boundary that leaves room for
///   [`TRUNCATION_SUFFIX`], so no code point is split;
/// - the suffix is appended, keeping the result within the limit.
fn truncate_tool_result(text: String) -> String {
    if text.len() > MAX_TOOL_RESULT_BYTES {
        let budget = MAX_TOOL_RESULT_BYTES.saturating_sub(TRUNCATION_SUFFIX.len());
        // Index 0 is always a char boundary, so the search always succeeds.
        let cut = (0..=budget)
            .rev()
            .find(|&idx| text.is_char_boundary(idx))
            .unwrap_or(0);
        [&text[..cut], TRUNCATION_SUFFIX].concat()
    } else {
        text
    }
}
2112
2113#[cfg(test)]
2118mod tests {
2119 use super::*;
2120
2121 #[test]
2122 fn test_resolved_parallel_tool_calls_gating() {
2123 let mut config = LlmCallConfig::from(&RuntimeAgent::new("p", "gpt-5.2"));
2124
2125 assert_eq!(config.resolved_parallel_tool_calls(true), None);
2127 assert_eq!(config.resolved_parallel_tool_calls(false), None);
2128
2129 config.parallel_tool_calls = Some(true);
2131 assert_eq!(config.resolved_parallel_tool_calls(true), Some(true));
2132 assert_eq!(config.resolved_parallel_tool_calls(false), None);
2133
2134 config.parallel_tool_calls = Some(false);
2135 assert_eq!(config.resolved_parallel_tool_calls(true), Some(false));
2136 assert_eq!(config.resolved_parallel_tool_calls(false), None);
2137 }
2138
2139 #[test]
2140 fn test_chat_driver_default_omits_parallel_tool_calls() {
2141 struct DefaultDriver;
2143 #[async_trait]
2144 impl ChatDriver for DefaultDriver {
2145 async fn chat_completion_stream(
2146 &self,
2147 _messages: Vec<LlmMessage>,
2148 _config: &LlmCallConfig,
2149 ) -> Result<LlmResponseStream> {
2150 unreachable!()
2151 }
2152 }
2153 assert!(!DefaultDriver.supports_parallel_tool_calls("any-model"));
2154 }
2155
2156 #[test]
2157 fn test_fold_system_messages_none_when_absent() {
2158 let messages = vec![
2159 LlmMessage::text(LlmMessageRole::User, "hi"),
2160 LlmMessage::text(LlmMessageRole::Assistant, "ok"),
2161 ];
2162 assert_eq!(fold_system_messages(&messages), None);
2163 }
2164
2165 #[test]
2166 fn test_fold_system_messages_single() {
2167 let messages = vec![
2168 LlmMessage::text(LlmMessageRole::System, "AGENT-PROMPT"),
2169 LlmMessage::text(LlmMessageRole::User, "hi"),
2170 ];
2171 assert_eq!(
2172 fold_system_messages(&messages),
2173 Some("AGENT-PROMPT".to_string())
2174 );
2175 }
2176
2177 #[test]
2178 fn test_fold_system_messages_accumulates_in_order() {
2179 let messages = vec![
2183 LlmMessage::text(LlmMessageRole::System, "A"),
2184 LlmMessage::text(LlmMessageRole::User, "hi"),
2185 LlmMessage::text(LlmMessageRole::Assistant, "ok"),
2186 LlmMessage::text(LlmMessageRole::System, "B"),
2187 ];
2188 assert_eq!(fold_system_messages(&messages), Some("A\n\nB".to_string()));
2189 }
2190
2191 #[test]
2192 fn test_fold_system_messages_concatenates_parts() {
2193 let messages = vec![LlmMessage::parts(
2194 LlmMessageRole::System,
2195 vec![
2196 LlmContentPart::text("foo"),
2197 LlmContentPart::image("data:image/png;base64,xxx"),
2198 LlmContentPart::text("bar"),
2199 ],
2200 )];
2201 assert_eq!(fold_system_messages(&messages), Some("foobar".to_string()));
2202 }
2203
2204 #[test]
2205 fn test_llm_call_config_builder_from_runtime_agent() {
2206 let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
2207 let llm_config = LlmCallConfigBuilder::from(&runtime_agent).build();
2208
2209 assert_eq!(llm_config.model, "gpt-4o");
2210 assert!(llm_config.reasoning_effort.is_none());
2211 assert!(llm_config.temperature.is_none());
2212 assert!(llm_config.max_tokens.is_none());
2213 assert!(llm_config.tools.is_empty());
2214 assert!(llm_config.metadata.is_empty());
2215 assert!(llm_config.openrouter_routing.is_none());
2217 }
2218
2219 #[test]
2220 fn runtime_agent_openrouter_routing_flows_into_call_config() {
2221 let mut runtime_agent = RuntimeAgent::new("You are helpful", "openai/gpt-5-mini");
2225 runtime_agent.openrouter_routing = Some(OpenRouterRoutingConfig {
2226 server_tools: vec![OpenRouterServerTool::new(
2227 OpenRouterServerToolKind::WebSearch,
2228 )],
2229 ..Default::default()
2230 });
2231
2232 let llm_config = LlmCallConfig::from(&runtime_agent);
2233 let routing = llm_config
2234 .openrouter_routing
2235 .expect("server-tool routing survives into the call config");
2236 assert_eq!(routing.server_tools.len(), 1);
2237 assert_eq!(
2238 routing.server_tools[0].kind.wire_type(),
2239 "openrouter:web_search"
2240 );
2241 }
2242
2243 #[test]
2244 fn test_llm_call_config_builder_with_metadata() {
2245 let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
2246 let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
2247 .with_metadata("session_id", "session_abc123")
2248 .with_metadata("agent_id", "agent_xyz789")
2249 .build();
2250
2251 assert_eq!(
2252 llm_config.metadata.get("session_id"),
2253 Some(&"session_abc123".to_string())
2254 );
2255 assert_eq!(
2256 llm_config.metadata.get("agent_id"),
2257 Some(&"agent_xyz789".to_string())
2258 );
2259 }
2260
2261 #[test]
2262 fn test_llm_call_config_builder_with_metadata_hashmap() {
2263 let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
2264 let mut metadata = HashMap::new();
2265 metadata.insert("key1".to_string(), "value1".to_string());
2266 metadata.insert("key2".to_string(), "value2".to_string());
2267
2268 let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
2269 .metadata(metadata)
2270 .build();
2271
2272 assert_eq!(llm_config.metadata.get("key1"), Some(&"value1".to_string()));
2273 assert_eq!(llm_config.metadata.get("key2"), Some(&"value2".to_string()));
2274 }
2275
2276 #[test]
2277 fn test_llm_call_config_builder_with_reasoning_effort() {
2278 let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
2279 let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
2280 .reasoning_effort("high")
2281 .build();
2282
2283 assert_eq!(llm_config.reasoning_effort, Some("high".to_string()));
2284 }
2285
2286 #[test]
2287 fn test_llm_call_config_builder_with_all_options() {
2288 let runtime_agent = RuntimeAgent::new("You are helpful", "gpt-4o");
2289 let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
2290 .model("claude-3-opus")
2291 .reasoning_effort("medium")
2292 .temperature(0.7)
2293 .max_tokens(1000)
2294 .build();
2295
2296 assert_eq!(llm_config.model, "claude-3-opus");
2297 assert_eq!(llm_config.reasoning_effort, Some("medium".to_string()));
2298 assert_eq!(llm_config.temperature, Some(0.7));
2299 assert_eq!(llm_config.max_tokens, Some(1000));
2300 }
2301
2302 #[test]
2303 fn test_llm_call_config_builder_with_openrouter_routing() {
2304 let runtime_agent = RuntimeAgent::new("You are helpful", "openai/gpt-5-mini");
2305 let routing = OpenRouterRoutingConfig::fallback_models([
2306 "openai/gpt-5-mini",
2307 "anthropic/claude-sonnet-4.5",
2308 ]);
2309
2310 let llm_config = LlmCallConfigBuilder::from(&runtime_agent)
2311 .openrouter_routing(routing.clone())
2312 .build();
2313
2314 assert_eq!(llm_config.openrouter_routing, Some(routing));
2315 }
2316
2317 #[test]
2318 fn test_openrouter_fallback_models_empty_is_empty() {
2319 let routing = OpenRouterRoutingConfig::fallback_models(std::iter::empty::<String>());
2320
2321 assert!(routing.is_empty());
2322 assert_eq!(routing.route, None);
2323 }
2324
2325 #[test]
2326 fn test_openrouter_routing_validates_primary_model() {
2327 let routing = OpenRouterRoutingConfig::fallback_models([
2328 "openai/gpt-5-mini",
2329 "anthropic/claude-sonnet-4.5",
2330 ]);
2331
2332 assert!(
2333 routing
2334 .validate_for_primary_model("openai/gpt-5-mini")
2335 .is_ok()
2336 );
2337 let err = routing
2338 .validate_for_primary_model("anthropic/claude-sonnet-4.5")
2339 .unwrap_err();
2340 assert!(err.contains("models[0]"));
2341 }
2342
2343 #[test]
2344 fn test_openrouter_routing_rejects_fallback_without_models() {
2345 let routing = OpenRouterRoutingConfig {
2346 route: Some(OpenRouterRoute::Fallback),
2347 ..Default::default()
2348 };
2349
2350 let err = routing
2351 .validate_for_primary_model("openai/gpt-5-mini")
2352 .unwrap_err();
2353 assert!(err.contains("requires at least one model"));
2354 }
2355
2356 #[test]
2357 fn test_openrouter_routing_serializes_request_fields() {
2358 let routing = OpenRouterRoutingConfig {
2359 models: vec![
2360 "openai/gpt-5-mini".to_string(),
2361 "anthropic/claude-sonnet-4.5".to_string(),
2362 ],
2363 route: Some(OpenRouterRoute::Fallback),
2364 provider: Some(OpenRouterProviderRouting {
2365 order: vec!["anthropic".to_string(), "openai".to_string()],
2366 allow_fallbacks: Some(false),
2367 require_parameters: Some(true),
2368 data_collection: Some(OpenRouterDataCollection::Deny),
2369 zdr: Some(true),
2370 sort: Some(OpenRouterProviderSort::Advanced(
2371 OpenRouterProviderSortOptions {
2372 by: OpenRouterProviderSortBy::Throughput,
2373 partition: Some(OpenRouterSortPartition::None),
2374 },
2375 )),
2376 max_price: Some(OpenRouterMaxPrice {
2377 prompt: Some(1.0),
2378 completion: Some(2.0),
2379 ..Default::default()
2380 }),
2381 ..Default::default()
2382 }),
2383 ..Default::default()
2384 };
2385
2386 let json = serde_json::to_value(routing).unwrap();
2387
2388 assert_eq!(
2389 json,
2390 serde_json::json!({
2391 "models": [
2392 "openai/gpt-5-mini",
2393 "anthropic/claude-sonnet-4.5"
2394 ],
2395 "route": "fallback",
2396 "provider": {
2397 "order": ["anthropic", "openai"],
2398 "allow_fallbacks": false,
2399 "require_parameters": true,
2400 "data_collection": "deny",
2401 "zdr": true,
2402 "sort": {
2403 "by": "throughput",
2404 "partition": "none"
2405 },
2406 "max_price": {
2407 "prompt": 1.0,
2408 "completion": 2.0
2409 }
2410 }
2411 })
2412 );
2413 }
2414
2415 #[test]
2416 fn test_provider_type_parsing() {
2417 assert_eq!("openai".parse::<DriverId>().unwrap(), DriverId::OpenAI);
2418 assert_eq!(
2419 "openrouter".parse::<DriverId>().unwrap(),
2420 DriverId::OpenRouter
2421 );
2422 assert_eq!(
2423 "openai_completions".parse::<DriverId>().unwrap(),
2424 DriverId::OpenAICompletions
2425 );
2426 assert_eq!(
2427 "azure_openai".parse::<DriverId>().unwrap(),
2428 DriverId::AzureOpenAI
2429 );
2430 assert_eq!(
2431 "anthropic".parse::<DriverId>().unwrap(),
2432 DriverId::Anthropic
2433 );
2434 assert_eq!("gemini".parse::<DriverId>().unwrap(), DriverId::Gemini);
2435 assert_eq!(
2437 "ollama".parse::<DriverId>().unwrap(),
2438 DriverId::external("ollama")
2439 );
2440 assert_eq!(
2441 "custom".parse::<DriverId>().unwrap(),
2442 DriverId::external("custom")
2443 );
2444 }
2445
2446 #[test]
2447 fn test_external_provider_id_is_case_insensitive() {
2448 assert_eq!("OpenAI".parse::<DriverId>().unwrap(), DriverId::OpenAI);
2451 assert_eq!(
2452 "Ollama".parse::<DriverId>().unwrap(),
2453 "ollama".parse::<DriverId>().unwrap()
2454 );
2455 assert_eq!(DriverId::external("OpenAI-Codex").as_str(), "openai-codex");
2456 assert_eq!(
2458 DriverId::external("MyProvider"),
2459 "myprovider".parse::<DriverId>().unwrap()
2460 );
2461 }
2462
2463 #[test]
2464 fn test_provider_type_display() {
2465 assert_eq!(DriverId::OpenAI.to_string(), "openai");
2466 assert_eq!(DriverId::OpenRouter.to_string(), "openrouter");
2467 assert_eq!(DriverId::AzureOpenAI.to_string(), "azure_openai");
2468 assert_eq!(
2469 DriverId::OpenAICompletions.to_string(),
2470 "openai_completions"
2471 );
2472 assert_eq!(DriverId::Anthropic.to_string(), "anthropic");
2473 assert_eq!(DriverId::Gemini.to_string(), "gemini");
2474 }
2475
2476 #[test]
2477 fn test_provider_config_builder() {
2478 let config = ProviderConfig::new(DriverId::Anthropic)
2479 .with_api_key("test-key")
2480 .with_base_url("https://custom.api.com");
2481
2482 assert_eq!(config.provider_type, DriverId::Anthropic);
2483 assert_eq!(config.api_key, Some("test-key".to_string()));
2484 assert_eq!(config.base_url, Some("https://custom.api.com".to_string()));
2485 }
2486
2487 #[test]
2488 fn test_driver_registry_requires_api_key() {
2489 let mut registry = DriverRegistry::new();
2491 registry.register(DriverId::OpenAI, |_config| {
2492 struct MockDriver;
2494 #[async_trait]
2495 impl ChatDriver for MockDriver {
2496 async fn chat_completion_stream(
2497 &self,
2498 _messages: Vec<LlmMessage>,
2499 _config: &LlmCallConfig,
2500 ) -> Result<LlmResponseStream> {
2501 unimplemented!()
2502 }
2503 }
2504 Box::new(MockDriver)
2505 });
2506
2507 let config = ProviderConfig::new(DriverId::OpenAI);
2509 let result = registry.create_chat_driver(&config);
2510 assert!(result.is_err());
2511
2512 let config_with_key = ProviderConfig::new(DriverId::OpenAI).with_api_key("test-key");
2514 let result = registry.create_chat_driver(&config_with_key);
2515 assert!(result.is_ok());
2516 }
2517
2518 #[test]
2519 fn test_driver_registry_returns_error_for_unregistered_provider() {
2520 let registry = DriverRegistry::new();
2521 let config = ProviderConfig::new(DriverId::Anthropic).with_api_key("test-key");
2522
2523 let result = registry.create_chat_driver(&config);
2524
2525 if let Err(AgentLoopError::DriverNotRegistered(provider)) = result {
2527 assert_eq!(provider, "anthropic");
2528 } else {
2529 panic!("Expected DriverNotRegistered error");
2530 }
2531 }
2532
2533 #[test]
2534 fn test_driver_registry_registration() {
2535 let mut registry = DriverRegistry::new();
2536
2537 assert!(!registry.has_driver(&DriverId::OpenAI));
2538 assert!(!registry.has_driver(&DriverId::Anthropic));
2539
2540 registry.register(DriverId::OpenAI, |_config| {
2541 struct MockDriver;
2542 #[async_trait]
2543 impl ChatDriver for MockDriver {
2544 async fn chat_completion_stream(
2545 &self,
2546 _messages: Vec<LlmMessage>,
2547 _config: &LlmCallConfig,
2548 ) -> Result<LlmResponseStream> {
2549 unimplemented!()
2550 }
2551 }
2552 Box::new(MockDriver)
2553 });
2554
2555 assert!(registry.has_driver(&DriverId::OpenAI));
2556 assert!(!registry.has_driver(&DriverId::Anthropic));
2557 }
2558
2559 #[test]
2560 fn test_register_external_and_create_driver_without_api_key() {
2561 struct MockDriver;
2562 #[async_trait]
2563 impl ChatDriver for MockDriver {
2564 async fn chat_completion_stream(
2565 &self,
2566 _messages: Vec<LlmMessage>,
2567 _config: &LlmCallConfig,
2568 ) -> Result<LlmResponseStream> {
2569 unimplemented!()
2570 }
2571 }
2572
2573 let mut registry = DriverRegistry::new();
2574 registry.register_external("openai-codex", |config| {
2575 assert_eq!(config.provider_type, DriverId::external("openai-codex"));
2577 Box::new(MockDriver)
2578 });
2579
2580 assert!(registry.has_driver(&DriverId::external("openai-codex")));
2581
2582 let config = ProviderConfig::new(DriverId::external("openai-codex")).with_metadata(
2584 ProviderMetadata {
2585 refresh_token: Some("rt".into()),
2586 ..Default::default()
2587 },
2588 );
2589 assert!(registry.create_chat_driver(&config).is_ok());
2590 }
2591
2592 #[test]
2593 fn test_register_defaults_to_chat_only_descriptor() {
2594 struct MockDriver;
2595 #[async_trait]
2596 impl ChatDriver for MockDriver {
2597 async fn chat_completion_stream(
2598 &self,
2599 _messages: Vec<LlmMessage>,
2600 _config: &LlmCallConfig,
2601 ) -> Result<LlmResponseStream> {
2602 unimplemented!()
2603 }
2604 }
2605
2606 let mut registry = DriverRegistry::new();
2607 registry.register(DriverId::Anthropic, |_config| Box::new(MockDriver));
2608
2609 let descriptor = registry.descriptor(&DriverId::Anthropic).unwrap();
2610 assert_eq!(descriptor.display_name, "Anthropic");
2611 assert_eq!(descriptor.services, vec![ServiceKind::Chat]);
2612 assert!(descriptor.chat.is_some());
2613 assert_eq!(descriptor.credential_schema.fields.len(), 1);
2615 assert_eq!(descriptor.credential_schema.fields[0].name, "api_key");
2616 assert!(descriptor.credential_schema.fields[0].required);
2617
2618 registry.register(DriverId::LlmSim, |_config| Box::new(MockDriver));
2620 let sim = registry.descriptor(&DriverId::LlmSim).unwrap();
2621 assert!(sim.credential_schema.fields.is_empty());
2622 }
2623
2624 #[test]
2625 fn test_descriptor_services_and_lookup() {
2626 struct MockDriver;
2627 #[async_trait]
2628 impl ChatDriver for MockDriver {
2629 async fn chat_completion_stream(
2630 &self,
2631 _messages: Vec<LlmMessage>,
2632 _config: &LlmCallConfig,
2633 ) -> Result<LlmResponseStream> {
2634 unimplemented!()
2635 }
2636 }
2637
2638 let mut registry = DriverRegistry::new();
2639 registry.register_descriptor(DriverDescriptor {
2640 services: vec![ServiceKind::Chat, ServiceKind::Realtime],
2641 ..DriverDescriptor::chat_only(DriverId::OpenAI, |_config| Box::new(MockDriver))
2642 });
2643 registry.register(DriverId::Anthropic, |_config| Box::new(MockDriver));
2644
2645 assert!(registry.supports(&DriverId::OpenAI, ServiceKind::Chat));
2646 assert!(registry.supports(&DriverId::OpenAI, ServiceKind::Realtime));
2647 assert!(!registry.supports(&DriverId::Anthropic, ServiceKind::Realtime));
2648 assert!(!registry.supports(&DriverId::Gemini, ServiceKind::Chat));
2649
2650 let realtime = registry.providers_for(ServiceKind::Realtime);
2651 assert_eq!(realtime, vec![DriverId::OpenAI]);
2652 let mut chat = registry.providers_for(ServiceKind::Chat);
2653 chat.sort_by_key(|p| p.to_string());
2654 assert_eq!(chat, vec![DriverId::Anthropic, DriverId::OpenAI]);
2655 }
2656
2657 #[test]
2658 fn test_create_chat_driver_fails_without_chat_factory() {
2659 let mut registry = DriverRegistry::new();
2660 registry.register_descriptor(DriverDescriptor {
2661 id: DriverId::external("embeddings-only"),
2662 display_name: "Embeddings Only".to_string(),
2663 services: vec![ServiceKind::Embeddings],
2664 credential_schema: CredentialFormSchema::empty(),
2665 oauth: None,
2666 chat: None,
2667 embeddings: None,
2668 });
2669
2670 let config = ProviderConfig::new(DriverId::external("embeddings-only"));
2671 let err = match registry.create_chat_driver(&config) {
2672 Ok(_) => panic!("expected error for missing chat factory"),
2673 Err(err) => err,
2674 };
2675 assert!(
2676 err.to_string()
2677 .contains("does not implement the chat service"),
2678 "unexpected error: {err}"
2679 );
2680 }
2681
2682 #[test]
2683 #[should_panic(expected = "already registered")]
2684 fn test_register_duplicate_panics() {
2685 struct MockDriver;
2686 #[async_trait]
2687 impl ChatDriver for MockDriver {
2688 async fn chat_completion_stream(
2689 &self,
2690 _messages: Vec<LlmMessage>,
2691 _config: &LlmCallConfig,
2692 ) -> Result<LlmResponseStream> {
2693 unimplemented!()
2694 }
2695 }
2696
2697 let mut registry = DriverRegistry::new();
2698 registry.register(DriverId::OpenAI, |_config| Box::new(MockDriver));
2699 registry.register(DriverId::OpenAI, |_config| Box::new(MockDriver));
2701 }
2702
2703 #[test]
2704 fn test_register_or_replace_overwrites() {
2705 struct MockDriver;
2706 #[async_trait]
2707 impl ChatDriver for MockDriver {
2708 async fn chat_completion_stream(
2709 &self,
2710 _messages: Vec<LlmMessage>,
2711 _config: &LlmCallConfig,
2712 ) -> Result<LlmResponseStream> {
2713 unimplemented!()
2714 }
2715 }
2716
2717 let mut registry = DriverRegistry::new();
2718 registry.register(DriverId::LlmSim, |_config| Box::new(MockDriver));
2719 registry.register_or_replace(DriverId::LlmSim, |_config| Box::new(MockDriver));
2721 assert!(registry.has_driver(&DriverId::LlmSim));
2722 }
2723
2724 use crate::{ContentPart, ImageFileContentPart, Message, MessageRole, TextContentPart};
2729
2730 #[test]
2731 fn test_message_has_image_files_with_image_file() {
2732 let message = Message {
2733 id: uuid::Uuid::new_v4().into(),
2734 role: MessageRole::User,
2735 content: vec![
2736 ContentPart::Text(TextContentPart {
2737 text: "Look at this image".to_string(),
2738 }),
2739 ContentPart::ImageFile(ImageFileContentPart {
2740 image_id: uuid::Uuid::new_v4().into(),
2741 filename: Some("test.png".to_string()),
2742 }),
2743 ],
2744 phase: None,
2745 thinking: None,
2746 thinking_signature: None,
2747 controls: None,
2748 metadata: None,
2749 external_actor: None,
2750 created_at: chrono::Utc::now(),
2751 };
2752
2753 assert!(LlmMessage::message_has_image_files(&message));
2754 }
2755
2756 #[test]
2757 fn test_message_has_image_files_without_image_file() {
2758 let message = Message {
2759 id: uuid::Uuid::new_v4().into(),
2760 role: MessageRole::User,
2761 content: vec![ContentPart::Text(TextContentPart {
2762 text: "Just text".to_string(),
2763 })],
2764 phase: None,
2765 thinking: None,
2766 thinking_signature: None,
2767 controls: None,
2768 metadata: None,
2769 external_actor: None,
2770 created_at: chrono::Utc::now(),
2771 };
2772
2773 assert!(!LlmMessage::message_has_image_files(&message));
2774 }
2775
2776 #[test]
2777 fn test_extract_image_file_ids() {
2778 let id1 = uuid::Uuid::new_v4();
2779 let id2 = uuid::Uuid::new_v4();
2780
2781 let message = Message {
2782 id: uuid::Uuid::new_v4().into(),
2783 role: MessageRole::User,
2784 content: vec![
2785 ContentPart::Text(TextContentPart {
2786 text: "Look at these images".to_string(),
2787 }),
2788 ContentPart::ImageFile(ImageFileContentPart {
2789 image_id: id1.into(),
2790 filename: Some("test1.png".to_string()),
2791 }),
2792 ContentPart::ImageFile(ImageFileContentPart {
2793 image_id: id2.into(),
2794 filename: Some("test2.png".to_string()),
2795 }),
2796 ],
2797 phase: None,
2798 thinking: None,
2799 thinking_signature: None,
2800 controls: None,
2801 metadata: None,
2802 external_actor: None,
2803 created_at: chrono::Utc::now(),
2804 };
2805
2806 let ids = LlmMessage::extract_image_file_ids(&message);
2807 assert_eq!(ids.len(), 2);
2808 assert!(ids.contains(&id1));
2809 assert!(ids.contains(&id2));
2810 }
2811
2812 #[test]
2813 fn test_from_message_with_images_text_only() {
2814 let message = Message {
2815 id: uuid::Uuid::new_v4().into(),
2816 role: MessageRole::User,
2817 content: vec![ContentPart::Text(TextContentPart {
2818 text: "Hello".to_string(),
2819 })],
2820 phase: None,
2821 thinking: None,
2822 thinking_signature: None,
2823 controls: None,
2824 metadata: None,
2825 external_actor: None,
2826 created_at: chrono::Utc::now(),
2827 };
2828
2829 let resolved = std::collections::HashMap::new();
2830 let llm_message = LlmMessage::from_message_with_images(&message, &resolved);
2831
2832 assert_eq!(llm_message.role, LlmMessageRole::User);
2833 match llm_message.content {
2834 LlmMessageContent::Text(text) => assert_eq!(text, "Hello"),
2835 _ => panic!("Expected text content"),
2836 }
2837 }
2838
2839 #[test]
2840 fn test_from_message_with_images_resolved_image() {
2841 let image_id = uuid::Uuid::new_v4();
2842 let message = Message {
2843 id: uuid::Uuid::new_v4().into(),
2844 role: MessageRole::User,
2845 content: vec![
2846 ContentPart::Text(TextContentPart {
2847 text: "Look at this".to_string(),
2848 }),
2849 ContentPart::ImageFile(ImageFileContentPart {
2850 image_id: image_id.into(),
2851 filename: Some("test.png".to_string()),
2852 }),
2853 ],
2854 phase: None,
2855 thinking: None,
2856 thinking_signature: None,
2857 controls: None,
2858 metadata: None,
2859 external_actor: None,
2860 created_at: chrono::Utc::now(),
2861 };
2862
2863 let mut resolved = std::collections::HashMap::new();
2864 resolved.insert(
2865 image_id,
2866 crate::ResolvedImage::new("base64data", "image/png"),
2867 );
2868
2869 let llm_message = LlmMessage::from_message_with_images(&message, &resolved);
2870
2871 match &llm_message.content {
2872 LlmMessageContent::Parts(parts) => {
2873 assert_eq!(parts.len(), 2);
2874 assert!(matches!(&parts[0], LlmContentPart::Text { .. }));
2876 if let LlmContentPart::Image { url } = &parts[1] {
2878 assert!(url.starts_with("data:image/png;base64,"));
2879 } else {
2880 panic!("Expected image content part");
2881 }
2882 }
2883 _ => panic!("Expected parts content"),
2884 }
2885 }
2886
2887 #[test]
2888 fn test_from_message_with_images_unresolved_image() {
2889 let image_id = uuid::Uuid::new_v4();
2890 let message = Message {
2891 id: uuid::Uuid::new_v4().into(),
2892 role: MessageRole::User,
2893 content: vec![ContentPart::ImageFile(ImageFileContentPart {
2894 image_id: image_id.into(),
2895 filename: Some("missing.png".to_string()),
2896 })],
2897 phase: None,
2898 thinking: None,
2899 thinking_signature: None,
2900 controls: None,
2901 metadata: None,
2902 external_actor: None,
2903 created_at: chrono::Utc::now(),
2904 };
2905
2906 let resolved = std::collections::HashMap::new();
2908 let llm_message = LlmMessage::from_message_with_images(&message, &resolved);
2909
2910 match &llm_message.content {
2913 LlmMessageContent::Text(text) => {
2914 assert!(text.contains("Image not found"));
2915 }
2916 LlmMessageContent::Parts(parts) => {
2917 assert_eq!(parts.len(), 1);
2918 if let LlmContentPart::Text { text } = &parts[0] {
2919 assert!(text.contains("Image not found"));
2920 } else {
2921 panic!("Expected text placeholder for missing image");
2922 }
2923 }
2924 }
2925 }
2926
2927 #[test]
2928 fn test_prepend_text_prefix_simple_text() {
2929 let mut msg = LlmMessage::text(LlmMessageRole::User, "Hello bot");
2930 msg.prepend_text_prefix("[Alice] ");
2931 assert_eq!(msg.content_as_text(), "[Alice] Hello bot");
2932 }
2933
2934 #[test]
2935 fn test_prepend_text_prefix_parts() {
2936 let mut msg = LlmMessage::parts(
2937 LlmMessageRole::User,
2938 vec![
2939 LlmContentPart::Text {
2940 text: "Hello".to_string(),
2941 },
2942 LlmContentPart::Image {
2943 url: "data:image/png;base64,abc".to_string(),
2944 },
2945 ],
2946 );
2947 msg.prepend_text_prefix("[Bob] ");
2948 match &msg.content {
2949 LlmMessageContent::Parts(parts) => {
2950 if let LlmContentPart::Text { text } = &parts[0] {
2951 assert_eq!(text, "[Bob] Hello");
2952 } else {
2953 panic!("Expected text part");
2954 }
2955 }
2956 _ => panic!("Expected parts content"),
2957 }
2958 }
2959
2960 #[test]
2961 fn test_prepend_text_prefix_parts_no_text() {
2962 let mut msg = LlmMessage::parts(
2963 LlmMessageRole::User,
2964 vec![LlmContentPart::Image {
2965 url: "data:image/png;base64,abc".to_string(),
2966 }],
2967 );
2968 msg.prepend_text_prefix("[Eve] ");
2969 match &msg.content {
2970 LlmMessageContent::Parts(parts) => {
2971 assert_eq!(parts.len(), 2);
2972 if let LlmContentPart::Text { text } = &parts[0] {
2973 assert_eq!(text, "[Eve] ");
2974 } else {
2975 panic!("Expected prepended text part");
2976 }
2977 }
2978 _ => panic!("Expected parts content"),
2979 }
2980 }
2981
2982 #[test]
2983 fn test_openrouter_plugin_config_is_empty() {
2984 assert!(OpenRouterPluginConfig::default().is_empty());
2985 assert!(
2986 !OpenRouterPluginConfig {
2987 web: Some(OpenRouterWebSearchPlugin::default()),
2988 file: None,
2989 }
2990 .is_empty()
2991 );
2992 assert!(
2993 !OpenRouterPluginConfig {
2994 web: None,
2995 file: Some(OpenRouterFilePlugin {}),
2996 }
2997 .is_empty()
2998 );
2999 }
3000
3001 #[test]
3002 fn test_openrouter_routing_is_empty_with_plugins() {
3003 let with_plugins = OpenRouterRoutingConfig {
3004 plugins: Some(OpenRouterPluginConfig {
3005 web: Some(OpenRouterWebSearchPlugin::default()),
3006 file: None,
3007 }),
3008 ..Default::default()
3009 };
3010 assert!(!with_plugins.is_empty());
3011
3012 let empty_plugins = OpenRouterRoutingConfig {
3013 plugins: Some(OpenRouterPluginConfig::default()),
3014 ..Default::default()
3015 };
3016 assert!(empty_plugins.is_empty());
3017 }
3018
3019 #[test]
3020 fn test_openrouter_web_search_plugin_serialization() {
3021 let plugin = OpenRouterWebSearchPlugin {
3022 max_results: Some(10),
3023 search_prompt: Some("search for Rust crates".to_string()),
3024 };
3025 let json = serde_json::to_value(&plugin).unwrap();
3026 assert_eq!(json["max_results"], 10);
3027 assert_eq!(json["search_prompt"], "search for Rust crates");
3028 }
3029
3030 #[test]
3031 fn test_openrouter_web_search_plugin_omits_none_fields() {
3032 let plugin = OpenRouterWebSearchPlugin::default();
3033 let json = serde_json::to_value(&plugin).unwrap();
3034 assert!(json.get("max_results").is_none());
3035 assert!(json.get("search_prompt").is_none());
3036 }
3037
3038 #[test]
3039 fn test_capacity_strategy_shared_capacity_is_noop() {
3040 let base = OpenRouterRoutingConfig {
3041 models: vec!["openai/gpt-5-mini".to_string()],
3042 capacity_strategy: Some(OpenRouterCapacityStrategy::SharedCapacity),
3043 ..Default::default()
3044 };
3045 let result = base.apply_capacity_strategy().unwrap();
3046 assert_eq!(
3047 result.capacity_strategy,
3048 Some(OpenRouterCapacityStrategy::SharedCapacity)
3049 );
3050 assert!(result.provider.is_none());
3051 }
3052
3053 #[test]
3054 fn test_capacity_strategy_none_is_noop() {
3055 let base = OpenRouterRoutingConfig {
3056 models: vec!["openai/gpt-5-mini".to_string()],
3057 capacity_strategy: None,
3058 ..Default::default()
3059 };
3060 let result = base.apply_capacity_strategy().unwrap();
3061 assert!(result.provider.is_none());
3062 }
3063
3064 #[test]
3065 fn test_capacity_strategy_byok_first_sets_allow_fallbacks() {
3066 let base = OpenRouterRoutingConfig {
3067 models: vec!["openai/gpt-5-mini".to_string()],
3068 capacity_strategy: Some(OpenRouterCapacityStrategy::ByokFirst),
3069 ..Default::default()
3070 };
3071 let result = base.apply_capacity_strategy().unwrap();
3072 let provider = result.provider.as_ref().expect("provider set by ByokFirst");
3073 assert_eq!(provider.allow_fallbacks, Some(true));
3074 }
3075
3076 #[test]
3077 fn test_capacity_strategy_byok_first_preserves_explicit_allow_fallbacks() {
3078 let base = OpenRouterRoutingConfig {
3080 models: vec!["openai/gpt-5-mini".to_string()],
3081 capacity_strategy: Some(OpenRouterCapacityStrategy::ByokFirst),
3082 provider: Some(OpenRouterProviderRouting {
3083 allow_fallbacks: Some(false),
3084 ..Default::default()
3085 }),
3086 ..Default::default()
3087 };
3088 let result = base.apply_capacity_strategy().unwrap();
3089 let provider = result.provider.as_ref().unwrap();
3090 assert_eq!(provider.allow_fallbacks, Some(false));
3091 }
3092
3093 #[test]
3094 fn test_capacity_strategy_byok_only_requires_provider_only() {
3095 let base = OpenRouterRoutingConfig {
3096 models: vec!["openai/gpt-5-mini".to_string()],
3097 capacity_strategy: Some(OpenRouterCapacityStrategy::ByokOnly),
3098 ..Default::default()
3099 };
3100 let err = base.apply_capacity_strategy().unwrap_err();
3101 assert!(
3102 err.contains("provider.only"),
3103 "error should mention provider.only: {err}"
3104 );
3105 }
3106
3107 #[test]
3108 fn test_capacity_strategy_byok_only_disables_fallbacks() {
3109 let base = OpenRouterRoutingConfig {
3110 models: vec!["openai/gpt-5-mini".to_string()],
3111 capacity_strategy: Some(OpenRouterCapacityStrategy::ByokOnly),
3112 provider: Some(OpenRouterProviderRouting {
3113 only: vec!["my-byok-provider".to_string()],
3114 ..Default::default()
3115 }),
3116 ..Default::default()
3117 };
3118 let result = base.apply_capacity_strategy().unwrap();
3119 let provider = result.provider.as_ref().unwrap();
3120 assert_eq!(provider.allow_fallbacks, Some(false));
3121 assert_eq!(provider.only, vec!["my-byok-provider"]);
3122 }
3123
3124 #[test]
3125 fn test_capacity_strategy_byok_only_not_empty_in_is_empty() {
3126 let with_strategy = OpenRouterRoutingConfig {
3127 capacity_strategy: Some(OpenRouterCapacityStrategy::ByokOnly),
3128 ..Default::default()
3129 };
3130 assert!(!with_strategy.is_empty());
3131
3132 let byok_first = OpenRouterRoutingConfig {
3133 capacity_strategy: Some(OpenRouterCapacityStrategy::ByokFirst),
3134 ..Default::default()
3135 };
3136 assert!(!byok_first.is_empty());
3137
3138 let shared = OpenRouterRoutingConfig {
3139 capacity_strategy: Some(OpenRouterCapacityStrategy::SharedCapacity),
3140 ..Default::default()
3141 };
3142 assert!(shared.is_empty());
3143 }
3144
3145 #[test]
3150 fn test_preset_no_presets_is_noop() {
3151 let base = OpenRouterRoutingConfig {
3152 models: vec!["openai/gpt-5-mini".to_string()],
3153 ..Default::default()
3154 };
3155 let result = base.apply_presets().unwrap();
3156 assert_eq!(result, base);
3157 }
3158
3159 #[test]
3160 fn test_preset_cheapest_with_tools_sets_require_parameters_and_sort_price() {
3161 let base = OpenRouterRoutingConfig {
3162 presets: vec![OpenRouterRoutingPreset::CheapestWithTools],
3163 ..Default::default()
3164 };
3165 let result = base.apply_presets().unwrap();
3166 assert!(result.presets.is_empty(), "presets cleared after apply");
3167 let provider = result.provider.expect("provider set by preset");
3168 assert_eq!(provider.require_parameters, Some(true));
3169 assert_eq!(
3170 provider.sort,
3171 Some(OpenRouterProviderSort::Simple(
3172 OpenRouterProviderSortBy::Price
3173 ))
3174 );
3175 }
3176
3177 #[test]
3178 fn test_preset_lowest_latency_review_sets_sort_throughput() {
3179 let base = OpenRouterRoutingConfig {
3180 presets: vec![OpenRouterRoutingPreset::LowestLatencyReview],
3181 ..Default::default()
3182 };
3183 let result = base.apply_presets().unwrap();
3184 let provider = result.provider.expect("provider set by preset");
3185 assert_eq!(
3186 provider.sort,
3187 Some(OpenRouterProviderSort::Simple(
3188 OpenRouterProviderSortBy::Throughput
3189 ))
3190 );
3191 }
3192
3193 #[test]
3194 fn test_preset_zdr_only_sets_zdr() {
3195 let base = OpenRouterRoutingConfig {
3196 presets: vec![OpenRouterRoutingPreset::ZdrOnly],
3197 ..Default::default()
3198 };
3199 let result = base.apply_presets().unwrap();
3200 let provider = result.provider.expect("provider set");
3201 assert_eq!(provider.zdr, Some(true));
3202 }
3203
3204 #[test]
3205 fn test_preset_byok_first_sets_allow_fallbacks() {
3206 let base = OpenRouterRoutingConfig {
3207 presets: vec![OpenRouterRoutingPreset::ByokFirst],
3208 ..Default::default()
3209 };
3210 let result = base.apply_presets().unwrap();
3211 let provider = result.provider.expect("provider set");
3212 assert_eq!(provider.allow_fallbacks, Some(true));
3213 }
3214
3215 #[test]
3216 fn test_preset_no_data_collection_sets_data_collection_deny() {
3217 let base = OpenRouterRoutingConfig {
3218 presets: vec![OpenRouterRoutingPreset::NoDataCollection],
3219 ..Default::default()
3220 };
3221 let result = base.apply_presets().unwrap();
3222 let provider = result.provider.expect("provider set");
3223 assert_eq!(
3224 provider.data_collection,
3225 Some(OpenRouterDataCollection::Deny)
3226 );
3227 }
3228
3229 #[test]
3230 fn test_preset_strict_json_sets_require_parameters() {
3231 let base = OpenRouterRoutingConfig {
3232 presets: vec![OpenRouterRoutingPreset::StrictJson],
3233 ..Default::default()
3234 };
3235 let result = base.apply_presets().unwrap();
3236 let provider = result.provider.expect("provider set");
3237 assert_eq!(provider.require_parameters, Some(true));
3238 }
3239
3240 #[test]
3241 fn test_preset_reasoning_required_sets_require_parameters() {
3242 let base = OpenRouterRoutingConfig {
3243 presets: vec![OpenRouterRoutingPreset::ReasoningRequired],
3244 ..Default::default()
3245 };
3246 let result = base.apply_presets().unwrap();
3247 let provider = result.provider.expect("provider set");
3248 assert_eq!(provider.require_parameters, Some(true));
3249 }
3250
3251 #[test]
3252 fn test_preset_max_price_converts_usd_per_million() {
3253 let base = OpenRouterRoutingConfig {
3254 presets: vec![OpenRouterRoutingPreset::MaxPrice {
3255 prompt_usd_per_million: Some(5.0),
3256 completion_usd_per_million: Some(15.0),
3257 }],
3258 ..Default::default()
3259 };
3260 let result = base.apply_presets().unwrap();
3261 let provider = result.provider.expect("provider set");
3262 let max_price = provider.max_price.expect("max_price set");
3263 let prompt = max_price.prompt.expect("prompt set");
3265 assert!((prompt - 5.0 / 1_000_000.0).abs() < f64::EPSILON);
3266 let completion = max_price.completion.expect("completion set");
3267 assert!((completion - 15.0 / 1_000_000.0).abs() < f64::EPSILON);
3268 }
3269
3270 #[test]
3271 fn test_preset_max_price_rejects_negative_values() {
3272 let base = OpenRouterRoutingConfig {
3273 presets: vec![OpenRouterRoutingPreset::MaxPrice {
3274 prompt_usd_per_million: Some(-1.0),
3275 completion_usd_per_million: None,
3276 }],
3277 ..Default::default()
3278 };
3279 let err = base.apply_presets().unwrap_err();
3280 assert!(
3281 err.contains("non-negative"),
3282 "error should mention non-negative: {err}"
3283 );
3284 }
3285
3286 #[test]
3287 fn test_preset_max_price_both_none_no_provider_field() {
3288 let base = OpenRouterRoutingConfig {
3289 presets: vec![OpenRouterRoutingPreset::MaxPrice {
3290 prompt_usd_per_million: None,
3291 completion_usd_per_million: None,
3292 }],
3293 ..Default::default()
3294 };
3295 let result = base.apply_presets().unwrap();
3296 assert!(
3297 result.provider.is_none(),
3298 "MaxPrice with no dimensions should not produce a provider field"
3299 );
3300 }
3301
3302 #[test]
3303 fn test_preset_explicit_provider_overrides_preset() {
3304 let base = OpenRouterRoutingConfig {
3305 presets: vec![OpenRouterRoutingPreset::CheapestWithTools],
3306 provider: Some(OpenRouterProviderRouting {
3307 sort: Some(OpenRouterProviderSort::Simple(
3309 OpenRouterProviderSortBy::Throughput,
3310 )),
3311 ..Default::default()
3312 }),
3313 ..Default::default()
3314 };
3315 let result = base.apply_presets().unwrap();
3316 let provider = result.provider.expect("provider set");
3317 assert_eq!(
3319 provider.sort,
3320 Some(OpenRouterProviderSort::Simple(
3321 OpenRouterProviderSortBy::Throughput
3322 ))
3323 );
3324 assert_eq!(provider.require_parameters, Some(true));
3326 }
3327
3328 #[test]
3329 fn test_preset_multiple_presets_combined() {
3330 let base = OpenRouterRoutingConfig {
3331 presets: vec![
3332 OpenRouterRoutingPreset::ZdrOnly,
3333 OpenRouterRoutingPreset::NoDataCollection,
3334 OpenRouterRoutingPreset::LowestLatencyReview,
3335 ],
3336 ..Default::default()
3337 };
3338 let result = base.apply_presets().unwrap();
3339 let provider = result.provider.expect("provider set");
3340 assert_eq!(provider.zdr, Some(true));
3341 assert_eq!(
3342 provider.data_collection,
3343 Some(OpenRouterDataCollection::Deny)
3344 );
3345 assert_eq!(
3346 provider.sort,
3347 Some(OpenRouterProviderSort::Simple(
3348 OpenRouterProviderSortBy::Throughput
3349 ))
3350 );
3351 }
3352
3353 #[test]
3354 fn test_preset_later_preset_overrides_sort() {
3355 let base = OpenRouterRoutingConfig {
3356 presets: vec![
3357 OpenRouterRoutingPreset::CheapestWithTools, OpenRouterRoutingPreset::LowestLatencyReview, ],
3360 ..Default::default()
3361 };
3362 let result = base.apply_presets().unwrap();
3363 let provider = result.provider.expect("provider set");
3364 assert_eq!(
3366 provider.sort,
3367 Some(OpenRouterProviderSort::Simple(
3368 OpenRouterProviderSortBy::Throughput
3369 ))
3370 );
3371 assert_eq!(provider.require_parameters, Some(true));
3373 }
3374
3375 #[test]
3376 fn test_preset_non_empty_in_is_empty() {
3377 let with_preset = OpenRouterRoutingConfig {
3378 presets: vec![OpenRouterRoutingPreset::ZdrOnly],
3379 ..Default::default()
3380 };
3381 assert!(!with_preset.is_empty());
3382
3383 let without = OpenRouterRoutingConfig::default();
3384 assert!(without.is_empty());
3385 }
3386}