1use anyhow::Result;
2use futures::Stream;
3use serde::{Deserialize, Serialize};
4
5use super::canonical::{map_to_canonical_model, CanonicalModelRegistry};
6use super::errors::ProviderError;
7use super::retry::RetryConfig;
8use crate::config::base::ConfigValue;
9use crate::conversation::message::Message;
10use crate::conversation::Conversation;
11use crate::model::ModelConfig;
12use crate::utils::safe_truncate;
13use rmcp::model::Tool;
14use utoipa::ToSchema;
15
16use once_cell::sync::Lazy;
17use std::ops::{Add, AddAssign};
18use std::pin::Pin;
19use std::sync::Mutex;
20
/// Process-wide record of the most recently selected model name.
/// Holds `None` until `set_current_model` is first called.
pub static CURRENT_MODEL: Lazy<Mutex<Option<String>>> = Lazy::new(|| Mutex::new(None));
23
24pub fn set_current_model(model: &str) {
26 if let Ok(mut current_model) = CURRENT_MODEL.lock() {
27 *current_model = Some(model.to_string());
28 }
29}
30
31pub fn get_current_model() -> Option<String> {
33 CURRENT_MODEL.lock().ok().and_then(|model| model.clone())
34}
35
/// Number of leading user messages sampled when auto-generating a session
/// name (see `Provider::get_initial_user_messages`).
/// NOTE(review): never mutated, so `const` would also work; kept `static`
/// to preserve the existing public interface.
pub static MSG_COUNT_FOR_SESSION_NAME_GENERATION: usize = 3;
37
/// Metadata about a model a provider can serve: its context window and,
/// optionally, per-token pricing and capability flags.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq)]
pub struct ModelInfo {
    // Provider-facing model identifier (e.g. "gpt-4o").
    pub name: String,
    // Maximum context window, in tokens.
    pub context_limit: usize,
    // Cost per input token, in `currency` units; `None` when unknown.
    pub input_token_cost: Option<f64>,
    // Cost per output token, in `currency` units; `None` when unknown.
    pub output_token_cost: Option<f64>,
    // Currency symbol for the costs above (e.g. "$"); `None` when unknown.
    pub currency: Option<String>,
    // Whether the model supports cache control; `None` when unknown.
    pub supports_cache_control: Option<bool>,
}
54
55impl ModelInfo {
56 pub fn new(name: impl Into<String>, context_limit: usize) -> Self {
58 Self {
59 name: name.into(),
60 context_limit,
61 input_token_cost: None,
62 output_token_cost: None,
63 currency: None,
64 supports_cache_control: None,
65 }
66 }
67
68 pub fn with_cost(
70 name: impl Into<String>,
71 context_limit: usize,
72 input_cost: f64,
73 output_cost: f64,
74 ) -> Self {
75 Self {
76 name: name.into(),
77 context_limit,
78 input_token_cost: Some(input_cost),
79 output_token_cost: Some(output_cost),
80 currency: Some("$".to_string()),
81 supports_cache_control: None,
82 }
83 }
84}
85
/// How a provider was registered with the application.
/// NOTE(review): the precise distinction between variants is not visible in
/// this file — confirm against the provider registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
pub enum ProviderType {
    Preferred,
    Builtin,
    Declarative,
    Custom,
}
93
/// Static, serializable description of a provider: identity, the models it
/// knows about, and the configuration keys it needs.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ProviderMetadata {
    // Canonical provider identifier used programmatically.
    pub name: String,
    // Human-readable name shown in UIs.
    pub display_name: String,
    // Short description of the provider.
    pub description: String,
    // Model selected when the user specifies none.
    pub default_model: String,
    // Models this provider is known to serve.
    pub known_models: Vec<ModelInfo>,
    // Link to the provider's model documentation.
    pub model_doc_link: String,
    // Configuration keys this provider reads (API keys, hosts, …).
    pub config_keys: Vec<ConfigKey>,
}
112
113impl ProviderMetadata {
114 pub fn new(
115 name: &str,
116 display_name: &str,
117 description: &str,
118 default_model: &str,
119 model_names: Vec<&str>,
120 model_doc_link: &str,
121 config_keys: Vec<ConfigKey>,
122 ) -> Self {
123 Self {
124 name: name.to_string(),
125 display_name: display_name.to_string(),
126 description: description.to_string(),
127 default_model: default_model.to_string(),
128 known_models: model_names
129 .iter()
130 .map(|&name| ModelInfo {
131 name: name.to_string(),
132 context_limit: ModelConfig::new_or_fail(name).context_limit(),
133 input_token_cost: None,
134 output_token_cost: None,
135 currency: None,
136 supports_cache_control: None,
137 })
138 .collect(),
139 model_doc_link: model_doc_link.to_string(),
140 config_keys,
141 }
142 }
143
144 pub fn with_models(
145 name: &str,
146 display_name: &str,
147 description: &str,
148 default_model: &str,
149 models: Vec<ModelInfo>,
150 model_doc_link: &str,
151 config_keys: Vec<ConfigKey>,
152 ) -> Self {
153 Self {
154 name: name.to_string(),
155 display_name: display_name.to_string(),
156 description: description.to_string(),
157 default_model: default_model.to_string(),
158 known_models: models,
159 model_doc_link: model_doc_link.to_string(),
160 config_keys,
161 }
162 }
163
164 pub fn empty() -> Self {
165 Self {
166 name: "".to_string(),
167 display_name: "".to_string(),
168 description: "".to_string(),
169 default_model: "".to_string(),
170 known_models: vec![],
171 model_doc_link: "".to_string(),
172 config_keys: vec![],
173 }
174 }
175}
176
/// A single configuration parameter a provider reads (e.g. an API key or
/// host URL).
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ConfigKey {
    // Configuration key name.
    pub name: String,
    // Whether the provider refuses to run without this key.
    pub required: bool,
    // Whether the value must be stored/displayed as a secret.
    pub secret: bool,
    // Default value used when the key is unset, if any.
    pub default: Option<String>,
    // Whether this key is obtained through an OAuth flow rather than
    // entered manually (see `ConfigKey::new_oauth`).
    pub oauth_flow: bool,
}
192
193impl ConfigKey {
194 pub fn new(name: &str, required: bool, secret: bool, default: Option<&str>) -> Self {
196 Self {
197 name: name.to_string(),
198 required,
199 secret,
200 default: default.map(|s| s.to_string()),
201 oauth_flow: false,
202 }
203 }
204
205 pub fn from_value_type<T: ConfigValue>(required: bool, secret: bool) -> Self {
206 Self {
207 name: T::KEY.to_string(),
208 required,
209 secret,
210 default: Some(T::DEFAULT.to_string()),
211 oauth_flow: false,
212 }
213 }
214
215 pub fn new_oauth(name: &str, required: bool, secret: bool, default: Option<&str>) -> Self {
220 Self {
221 name: name.to_string(),
222 required,
223 secret,
224 default: default.map(|s| s.to_string()),
225 oauth_flow: true,
226 }
227 }
228}
229
/// Token usage reported for a single completion, tagged with the model that
/// produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderUsage {
    // Name of the model that served the request.
    pub model: String,
    // Token counts for the request/response.
    pub usage: Usage,
}
235
236impl ProviderUsage {
237 pub fn new(model: String, usage: Usage) -> Self {
238 Self { model, usage }
239 }
240
241 pub async fn ensure_tokens(
243 &mut self,
244 system_prompt: &str,
245 request_messages: &[Message],
246 response: &Message,
247 tools: &[Tool],
248 ) -> Result<(), ProviderError> {
249 crate::providers::usage_estimator::ensure_usage_tokens(
250 self,
251 system_prompt,
252 request_messages,
253 response,
254 tools,
255 )
256 .await
257 .map_err(|e| ProviderError::ExecutionError(format!("Failed to ensure usage tokens: {}", e)))
258 }
259
260 pub fn combine_with(&self, other: &ProviderUsage) -> ProviderUsage {
263 ProviderUsage {
264 model: self.model.clone(),
265 usage: self.usage + other.usage,
266 }
267 }
268}
269
/// Token counts for a single request/response pair. Each field is `None`
/// when the provider did not report it (see `Usage::new` for how the total
/// is derived).
#[derive(Debug, Clone, Serialize, Deserialize, Default, Copy)]
pub struct Usage {
    pub input_tokens: Option<i32>,
    pub output_tokens: Option<i32>,
    pub total_tokens: Option<i32>,
}
276
/// Adds two optional values.
///
/// A value paired with `None` is returned unchanged, and two `None`s stay
/// `None` — absence means "no contribution", not zero. (The previous version
/// added `T::default()` in the one-sided cases, which was needless work and
/// silently assumed `Default` is an additive identity.)
fn sum_optionals<T>(a: Option<T>, b: Option<T>) -> Option<T>
where
    T: Add<Output = T>,
{
    match (a, b) {
        (Some(x), Some(y)) => Some(x + y),
        (Some(x), None) => Some(x),
        (None, Some(y)) => Some(y),
        (None, None) => None,
    }
}
288
impl Add for Usage {
    type Output = Self;

    /// Field-wise addition via `sum_optionals`. Routing the result through
    /// `Usage::new` re-derives `total_tokens` if both summed totals were
    /// `None` but partial counts exist.
    fn add(self, other: Self) -> Self {
        Self::new(
            sum_optionals(self.input_tokens, other.input_tokens),
            sum_optionals(self.output_tokens, other.output_tokens),
            sum_optionals(self.total_tokens, other.total_tokens),
        )
    }
}
300
impl AddAssign for Usage {
    /// In-place accumulation; delegates to the `Add` impl above.
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}
306
307impl Usage {
308 pub fn new(
309 input_tokens: Option<i32>,
310 output_tokens: Option<i32>,
311 total_tokens: Option<i32>,
312 ) -> Self {
313 let calculated_total = if total_tokens.is_none() {
314 match (input_tokens, output_tokens) {
315 (Some(input), Some(output)) => Some(input + output),
316 (Some(input), None) => Some(input),
317 (None, Some(output)) => Some(output),
318 (None, None) => None,
319 }
320 } else {
321 total_tokens
322 };
323
324 Self {
325 input_tokens,
326 output_tokens,
327 total_tokens: calculated_total,
328 }
329 }
330}
331
332use async_trait::async_trait;
333
/// Extra introspection implemented by lead/worker (two-model) providers;
/// reached from a plain `Provider` via `Provider::as_lead_worker`.
pub trait LeadWorkerProviderTrait {
    /// Returns two model names — presumably (lead, worker); confirm against
    /// the implementing provider.
    fn get_model_info(&self) -> (String, String);

    /// Returns the name of the model currently in use.
    fn get_active_model(&self) -> String;

    /// Returns three provider settings. NOTE(review): their meaning is not
    /// visible in this file — confirm against the implementor.
    fn get_settings(&self) -> (usize, usize, usize);
}
345
/// Core interface every LLM provider implements.
///
/// Only `metadata`, `get_name`, `complete_with_model`, and
/// `get_model_config` are required; the remaining methods have default
/// implementations built on top of them.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Static description of the provider (known models, config keys,
    /// documentation link). `Self: Sized` keeps the trait object-safe.
    fn metadata() -> ProviderMetadata
    where
        Self: Sized;

    /// Returns this provider's canonical name.
    fn get_name(&self) -> &str;

    /// Runs a completion against an explicit model configuration.
    async fn complete_with_model(
        &self,
        model_config: &ModelConfig,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError>;

    /// Runs a completion with this provider's default model configuration.
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let model_config = self.get_model_config();
        self.complete_with_model(&model_config, system, messages, tools)
            .await
    }

    /// Runs a completion with the "fast" variant of the configured model,
    /// falling back to the regular model if the fast attempt fails and the
    /// two models actually differ.
    async fn complete_fast(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let model_config = self.get_model_config();
        let fast_config = model_config.use_fast_model();

        match self
            .complete_with_model(&fast_config, system, messages, tools)
            .await
        {
            Ok(result) => Ok(result),
            Err(e) => {
                // Only fall back when the fast model is different; otherwise
                // retrying would just repeat the identical request.
                if fast_config.model_name != model_config.model_name {
                    tracing::warn!(
                        "Fast model {} failed with error: {}. Falling back to regular model {}",
                        fast_config.model_name,
                        e,
                        model_config.model_name
                    );
                    self.complete_with_model(&model_config, system, messages, tools)
                        .await
                } else {
                    Err(e)
                }
            }
        }
    }

    /// Returns the model configuration this provider was constructed with.
    fn get_model_config(&self) -> ModelConfig;

    /// Retry policy for transient failures; providers may override the
    /// default.
    fn retry_config(&self) -> RetryConfig {
        RetryConfig::default()
    }

    /// Lists model names this provider can serve, or `Ok(None)` when the
    /// provider cannot enumerate them (the default).
    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        Ok(None)
    }

    /// Filters the supported models down to those whose canonical registry
    /// entry accepts text input. Falls back to the unfiltered list when the
    /// filter matches nothing, and to `Ok(None)` when models cannot be
    /// enumerated at all.
    async fn fetch_recommended_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        let all_models = match self.fetch_supported_models().await? {
            Some(models) => models,
            None => return Ok(None),
        };

        let registry = CanonicalModelRegistry::bundled().map_err(|e| {
            ProviderError::ExecutionError(format!("Failed to load canonical registry: {}", e))
        })?;

        let provider_name = self.get_name();

        let recommended_models: Vec<String> = all_models
            .iter()
            .filter(|model| {
                // Keep only models the registry knows and that take text input.
                map_to_canonical_model(provider_name, model, registry)
                    .and_then(|canonical_id| registry.get(&canonical_id))
                    .map(|m| m.input_modalities.contains(&"text".to_string()))
                    .unwrap_or(false)
            })
            .cloned()
            .collect();

        if recommended_models.is_empty() {
            Ok(Some(all_models))
        } else {
            Ok(Some(recommended_models))
        }
    }

    /// Maps a provider-specific model name to its canonical id via the
    /// bundled registry, if a mapping exists.
    async fn map_to_canonical_model(
        &self,
        provider_model: &str,
    ) -> Result<Option<String>, ProviderError> {
        let registry = CanonicalModelRegistry::bundled().map_err(|e| {
            ProviderError::ExecutionError(format!("Failed to load canonical registry: {}", e))
        })?;

        Ok(map_to_canonical_model(
            self.get_name(),
            provider_model,
            registry,
        ))
    }

    /// Whether `create_embeddings` is supported; defaults to `false`.
    fn supports_embeddings(&self) -> bool {
        false
    }

    /// Whether the provider honors cache control; defaults to `false`.
    async fn supports_cache_control(&self) -> bool {
        false
    }

    /// Embeds the given texts. The default always errors; providers that
    /// return `true` from `supports_embeddings` are expected to override it.
    async fn create_embeddings(&self, _texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
        Err(ProviderError::ExecutionError(
            "This provider does not support embeddings".to_string(),
        ))
    }

    /// Downcast hook for lead/worker providers; `None` for everyone else.
    fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
        None
    }

    /// Streams a completion. The default returns `NotImplemented`; check
    /// `supports_streaming` before calling.
    async fn stream(
        &self,
        _system: &str,
        _messages: &[Message],
        _tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        Err(ProviderError::NotImplemented(
            "streaming not implemented".to_string(),
        ))
    }

    /// Whether `stream` is implemented; defaults to `false`.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Name of the model that would serve the next request, accounting for
    /// lead/worker providers that switch models at runtime.
    fn get_active_model_name(&self) -> String {
        if let Some(lead_worker) = self.as_lead_worker() {
            lead_worker.get_active_model()
        } else {
            self.get_model_config().model_name
        }
    }

    /// Collects the text of the first few user messages
    /// (`MSG_COUNT_FOR_SESSION_NAME_GENERATION`), used as context for
    /// session-name generation.
    fn get_initial_user_messages(&self, messages: &Conversation) -> Vec<String> {
        messages
            .iter()
            .filter(|m| m.role == rmcp::model::Role::User)
            .take(MSG_COUNT_FOR_SESSION_NAME_GENERATION)
            .map(|m| m.as_concat_text())
            .collect()
    }

    /// Asks the (fast) model for a short session name based on the first few
    /// user messages, collapses whitespace, and truncates the result to 100
    /// characters.
    async fn generate_session_name(
        &self,
        messages: &Conversation,
    ) -> Result<String, ProviderError> {
        let context = self.get_initial_user_messages(messages);
        let prompt = self.create_session_name_prompt(&context);
        let message = Message::user().with_text(&prompt);
        let result = self
            .complete_fast(
                "Reply with only a description in four words or less",
                &[message],
                &[],
            )
            .await?;

        // Normalize all whitespace runs (including newlines) to single spaces.
        let description = result
            .0
            .as_concat_text()
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(" ");

        Ok(safe_truncate(&description, 100))
    }

    /// Builds the prompt used by `generate_session_name`, prepending the
    /// collected user messages when any are available.
    fn create_session_name_prompt(&self, context: &[String]) -> String {
        let mut prompt = "Based on the conversation so far, provide a concise description of this session in 4 words or less. This will be used for finding the session later in a UI with limited space - reply *ONLY* with the description".to_string();

        if !context.is_empty() {
            prompt = format!(
                "Here are the first few user messages:\n{}\n\n{}",
                context.join("\n"),
                prompt
            );
        }
        prompt
    }

    /// Runs an interactive OAuth configuration flow. The default errors;
    /// providers whose config keys use `ConfigKey::new_oauth` override it.
    async fn configure_oauth(&self) -> Result<(), ProviderError> {
        Err(ProviderError::ExecutionError(
            "OAuth configuration not supported by this provider".to_string(),
        ))
    }
}
584
/// Streamed completion output: each item carries an optional message chunk
/// and/or optional usage data, or a provider error.
pub type MessageStream = Pin<
    Box<dyn Stream<Item = Result<(Option<Message>, Option<ProviderUsage>), ProviderError>> + Send>,
>;
591
592pub fn stream_from_single_message(message: Message, usage: ProviderUsage) -> MessageStream {
593 let stream = futures::stream::once(async move { Ok((Some(message), Some(usage))) });
594 Box::pin(stream)
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use serde_json::json;
    // Explicit total is stored as-is alongside the parts.
    #[test]
    fn test_usage_creation() {
        let usage = Usage::new(Some(10), Some(20), Some(30));
        assert_eq!(usage.input_tokens, Some(10));
        assert_eq!(usage.output_tokens, Some(20));
        assert_eq!(usage.total_tokens, Some(30));
    }

    // Usage survives a serde round-trip and serializes to the expected
    // field names/values.
    #[test]
    fn test_usage_serialization() -> Result<()> {
        let usage = Usage::new(Some(10), Some(20), Some(30));
        let serialized = serde_json::to_string(&usage)?;
        let deserialized: Usage = serde_json::from_str(&serialized)?;

        assert_eq!(usage.input_tokens, deserialized.input_tokens);
        assert_eq!(usage.output_tokens, deserialized.output_tokens);
        assert_eq!(usage.total_tokens, deserialized.total_tokens);

        // Check the raw JSON shape, not just round-trip equality.
        let json_value: serde_json::Value = serde_json::from_str(&serialized)?;
        assert_eq!(json_value["input_tokens"], json!(10));
        assert_eq!(json_value["output_tokens"], json!(20));
        assert_eq!(json_value["total_tokens"], json!(30));

        Ok(())
    }

    // The global current-model slot reflects the most recent set call.
    #[test]
    fn test_set_and_get_current_model() {
        set_current_model("gpt-4o");

        let model = get_current_model();
        assert_eq!(model, Some("gpt-4o".to_string()));

        set_current_model("claude-sonnet-4-20250514");

        let model = get_current_model();
        assert_eq!(model, Some("claude-sonnet-4-20250514".to_string()));
    }

    // ProviderMetadata::new resolves context limits through ModelConfig,
    // including the fallback limit for unknown model names.
    #[test]
    fn test_provider_metadata_context_limits() {
        let test_models = vec!["gpt-4o", "claude-sonnet-4-20250514", "unknown-model"];
        let metadata = ProviderMetadata::new(
            "test",
            "Test Provider",
            "Test Description",
            "gpt-4o",
            test_models,
            "https://example.com",
            vec![],
        );

        let model_info: HashMap<String, usize> = metadata
            .known_models
            .into_iter()
            .map(|m| (m.name, m.context_limit))
            .collect();

        assert_eq!(*model_info.get("gpt-4o").unwrap(), 128_000);

        assert_eq!(
            *model_info.get("claude-sonnet-4-20250514").unwrap(),
            200_000
        );

        // Unknown models fall back to the default context limit.
        assert_eq!(*model_info.get("unknown-model").unwrap(), 128_000);
    }

    // ModelInfo equality is field-wise (derived PartialEq).
    #[test]
    fn test_model_info_creation() {
        let info = ModelInfo {
            name: "test-model".to_string(),
            context_limit: 1000,
            input_token_cost: None,
            output_token_cost: None,
            currency: None,
            supports_cache_control: None,
        };
        assert_eq!(info.context_limit, 1000);

        let info2 = ModelInfo {
            name: "test-model".to_string(),
            context_limit: 1000,
            input_token_cost: None,
            output_token_cost: None,
            currency: None,
            supports_cache_control: None,
        };
        assert_eq!(info, info2);

        // Differing context_limit breaks equality.
        let info3 = ModelInfo {
            name: "test-model".to_string(),
            context_limit: 2000,
            input_token_cost: None,
            output_token_cost: None,
            currency: None,
            supports_cache_control: None,
        };
        assert_ne!(info, info3);
    }

    // with_cost fills pricing fields and defaults currency to "$".
    #[test]
    fn test_model_info_with_cost() {
        let info = ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001);
        assert_eq!(info.name, "gpt-4o");
        assert_eq!(info.context_limit, 128000);
        assert_eq!(info.input_token_cost, Some(0.0000025));
        assert_eq!(info.output_token_cost, Some(0.00001));
        assert_eq!(info.currency, Some("$".to_string()));
    }
}
725}