//! aster/providers/base.rs
//!
//! Base abstractions shared by all model providers.

1use anyhow::Result;
2use futures::Stream;
3use serde::{Deserialize, Serialize};
4
5use super::canonical::{map_to_canonical_model, CanonicalModelRegistry};
6use super::errors::ProviderError;
7use super::retry::RetryConfig;
8use crate::config::base::ConfigValue;
9use crate::conversation::message::Message;
10use crate::conversation::Conversation;
11use crate::model::ModelConfig;
12use crate::utils::safe_truncate;
13use rmcp::model::Tool;
14use utoipa::ToSchema;
15
16use once_cell::sync::Lazy;
17use std::ops::{Add, AddAssign};
18use std::pin::Pin;
19use std::sync::Mutex;
20
/// A global store for the current model being used. We use this because when a
/// provider returns, it tells us the real model that served the request, not an
/// alias.
///
/// NOTE(review): this is a single process-wide slot — if multiple sessions run
/// concurrently, the last writer wins; confirm that is acceptable for callers.
pub static CURRENT_MODEL: Lazy<Mutex<Option<String>>> = Lazy::new(|| Mutex::new(None));
23
24/// Set the current model in the global store
25pub fn set_current_model(model: &str) {
26    if let Ok(mut current_model) = CURRENT_MODEL.lock() {
27        *current_model = Some(model.to_string());
28    }
29}
30
31/// Get the current model from the global store, the real model, not an alias
32pub fn get_current_model() -> Option<String> {
33    CURRENT_MODEL.lock().ok().and_then(|model| model.clone())
34}
35
36pub static MSG_COUNT_FOR_SESSION_NAME_GENERATION: usize = 3;
37
/// Information about a model's capabilities
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema, PartialEq)]
pub struct ModelInfo {
    /// The name of the model
    pub name: String,
    /// The maximum context length this model supports (in tokens)
    pub context_limit: usize,
    /// Cost per input token, in `currency` units (optional)
    pub input_token_cost: Option<f64>,
    /// Cost per output token, in `currency` units (optional)
    pub output_token_cost: Option<f64>,
    /// Currency symbol for the costs (default: "$")
    pub currency: Option<String>,
    /// Whether this model supports cache control; `None` means unknown
    pub supports_cache_control: Option<bool>,
}
54
55impl ModelInfo {
56    /// Create a new ModelInfo with just name and context limit
57    pub fn new(name: impl Into<String>, context_limit: usize) -> Self {
58        Self {
59            name: name.into(),
60            context_limit,
61            input_token_cost: None,
62            output_token_cost: None,
63            currency: None,
64            supports_cache_control: None,
65        }
66    }
67
68    /// Create a new ModelInfo with cost information (per token)
69    pub fn with_cost(
70        name: impl Into<String>,
71        context_limit: usize,
72        input_cost: f64,
73        output_cost: f64,
74    ) -> Self {
75        Self {
76            name: name.into(),
77            context_limit,
78            input_token_cost: Some(input_cost),
79            output_token_cost: Some(output_cost),
80            currency: Some("$".to_string()),
81            supports_cache_control: None,
82        }
83    }
84}
85
/// How a provider implementation is sourced/registered.
///
/// NOTE(review): variant semantics are inferred from their names only —
/// confirm against the registry code that consumes this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
pub enum ProviderType {
    Preferred,
    Builtin,
    Declarative,
    Custom,
}
93
/// Metadata about a provider's configuration requirements and capabilities
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ProviderMetadata {
    /// The unique identifier for this provider
    pub name: String,
    /// Display name for the provider in UIs
    pub display_name: String,
    /// Description of the provider's capabilities
    pub description: String,
    /// The default/recommended model for this provider
    pub default_model: String,
    /// A list of currently known models with their capabilities
    pub known_models: Vec<ModelInfo>,
    /// Link to the docs where models can be found
    pub model_doc_link: String,
    /// Configuration keys for provider setup; each entry's `required` flag
    /// distinguishes mandatory keys from optional ones
    pub config_keys: Vec<ConfigKey>,
}
112
113impl ProviderMetadata {
114    pub fn new(
115        name: &str,
116        display_name: &str,
117        description: &str,
118        default_model: &str,
119        model_names: Vec<&str>,
120        model_doc_link: &str,
121        config_keys: Vec<ConfigKey>,
122    ) -> Self {
123        Self {
124            name: name.to_string(),
125            display_name: display_name.to_string(),
126            description: description.to_string(),
127            default_model: default_model.to_string(),
128            known_models: model_names
129                .iter()
130                .map(|&name| ModelInfo {
131                    name: name.to_string(),
132                    context_limit: ModelConfig::new_or_fail(name).context_limit(),
133                    input_token_cost: None,
134                    output_token_cost: None,
135                    currency: None,
136                    supports_cache_control: None,
137                })
138                .collect(),
139            model_doc_link: model_doc_link.to_string(),
140            config_keys,
141        }
142    }
143
144    pub fn with_models(
145        name: &str,
146        display_name: &str,
147        description: &str,
148        default_model: &str,
149        models: Vec<ModelInfo>,
150        model_doc_link: &str,
151        config_keys: Vec<ConfigKey>,
152    ) -> Self {
153        Self {
154            name: name.to_string(),
155            display_name: display_name.to_string(),
156            description: description.to_string(),
157            default_model: default_model.to_string(),
158            known_models: models,
159            model_doc_link: model_doc_link.to_string(),
160            config_keys,
161        }
162    }
163
164    pub fn empty() -> Self {
165        Self {
166            name: "".to_string(),
167            display_name: "".to_string(),
168            description: "".to_string(),
169            default_model: "".to_string(),
170            known_models: vec![],
171            model_doc_link: "".to_string(),
172            config_keys: vec![],
173        }
174    }
175}
176
/// Configuration key metadata for provider setup
///
/// Describes one configuration value a provider needs (API key, host, etc.)
/// and how it should be collected and stored.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ConfigKey {
    /// The name of the configuration key (e.g., "API_KEY")
    pub name: String,
    /// Whether this key is required for the provider to function
    pub required: bool,
    /// Whether this key should be stored securely (e.g., in keychain)
    pub secret: bool,
    /// Optional default value for the key
    pub default: Option<String>,
    /// Whether this key should be configured using OAuth device code flow.
    /// When true, the provider's configure_oauth() method will be called instead of prompting for manual input
    pub oauth_flow: bool,
}
192
193impl ConfigKey {
194    /// Create a new ConfigKey
195    pub fn new(name: &str, required: bool, secret: bool, default: Option<&str>) -> Self {
196        Self {
197            name: name.to_string(),
198            required,
199            secret,
200            default: default.map(|s| s.to_string()),
201            oauth_flow: false,
202        }
203    }
204
205    pub fn from_value_type<T: ConfigValue>(required: bool, secret: bool) -> Self {
206        Self {
207            name: T::KEY.to_string(),
208            required,
209            secret,
210            default: Some(T::DEFAULT.to_string()),
211            oauth_flow: false,
212        }
213    }
214
215    /// Create a new ConfigKey that uses OAuth device code flow for configuration
216    ///
217    /// This is used for providers that support OAuth authentication instead of manual API key entry.
218    /// When oauth_flow is true, the configuration system will call the provider's configure_oauth() method.
219    pub fn new_oauth(name: &str, required: bool, secret: bool, default: Option<&str>) -> Self {
220        Self {
221            name: name.to_string(),
222            required,
223            secret,
224            default: default.map(|s| s.to_string()),
225            oauth_flow: true,
226        }
227    }
228}
229
/// A provider response's attribution: which model actually served the request
/// plus the token usage reported for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderUsage {
    /// The model that produced the response (the real model, which may differ
    /// from a requested alias)
    pub model: String,
    /// Token counts reported (or later estimated) for the exchange
    pub usage: Usage,
}
235
236impl ProviderUsage {
237    pub fn new(model: String, usage: Usage) -> Self {
238        Self { model, usage }
239    }
240
241    /// Ensures this ProviderUsage has token counts, estimating them if necessary
242    pub async fn ensure_tokens(
243        &mut self,
244        system_prompt: &str,
245        request_messages: &[Message],
246        response: &Message,
247        tools: &[Tool],
248    ) -> Result<(), ProviderError> {
249        crate::providers::usage_estimator::ensure_usage_tokens(
250            self,
251            system_prompt,
252            request_messages,
253            response,
254            tools,
255        )
256        .await
257        .map_err(|e| ProviderError::ExecutionError(format!("Failed to ensure usage tokens: {}", e)))
258    }
259
260    /// Combine this ProviderUsage with another, adding their token counts
261    /// Uses the model from this ProviderUsage
262    pub fn combine_with(&self, other: &ProviderUsage) -> ProviderUsage {
263        ProviderUsage {
264            model: self.model.clone(),
265            usage: self.usage + other.usage,
266        }
267    }
268}
269
/// Token counts for a single exchange. Every field is optional because not
/// all providers report all counts; `Usage::new` derives a missing total from
/// whichever counts are present.
#[derive(Debug, Clone, Serialize, Deserialize, Default, Copy)]
pub struct Usage {
    pub input_tokens: Option<i32>,
    pub output_tokens: Option<i32>,
    pub total_tokens: Option<i32>,
}
276
/// Sum two optional values: `Some + Some` adds, a lone `Some` passes through
/// unchanged, and two `None`s stay `None` — so "unknown" is preserved rather
/// than coerced to zero.
///
/// The original `T: Default` bound was unnecessary: adding `T::default()` to
/// a lone value only made sense when the default is the additive identity, in
/// which case returning the lone value directly is equivalent (and cheaper).
fn sum_optionals<T>(a: Option<T>, b: Option<T>) -> Option<T>
where
    T: Add<Output = T>,
{
    match (a, b) {
        (Some(x), Some(y)) => Some(x + y),
        (some, None) => some,
        (None, some) => some,
    }
}
288
289impl Add for Usage {
290    type Output = Self;
291
292    fn add(self, other: Self) -> Self {
293        Self::new(
294            sum_optionals(self.input_tokens, other.input_tokens),
295            sum_optionals(self.output_tokens, other.output_tokens),
296            sum_optionals(self.total_tokens, other.total_tokens),
297        )
298    }
299}
300
301impl AddAssign for Usage {
302    fn add_assign(&mut self, rhs: Self) {
303        *self = *self + rhs;
304    }
305}
306
307impl Usage {
308    pub fn new(
309        input_tokens: Option<i32>,
310        output_tokens: Option<i32>,
311        total_tokens: Option<i32>,
312    ) -> Self {
313        let calculated_total = if total_tokens.is_none() {
314            match (input_tokens, output_tokens) {
315                (Some(input), Some(output)) => Some(input + output),
316                (Some(input), None) => Some(input),
317                (None, Some(output)) => Some(output),
318                (None, None) => None,
319            }
320        } else {
321            total_tokens
322        };
323
324        Self {
325            input_tokens,
326            output_tokens,
327            total_tokens: calculated_total,
328        }
329    }
330}
331
332use async_trait::async_trait;
333
/// Trait for LeadWorkerProvider-specific functionality.
///
/// Implemented only by lead/worker composite providers; regular providers
/// return `None` from `Provider::as_lead_worker`.
pub trait LeadWorkerProviderTrait {
    /// Get information about the lead and worker models for logging
    /// (presumably `(lead_model, worker_model)` — confirm ordering at impls)
    fn get_model_info(&self) -> (String, String);

    /// Get the currently active model name
    fn get_active_model(&self) -> String;

    /// Get (lead_turns, failure_threshold, fallback_turns)
    fn get_settings(&self) -> (usize, usize, usize);
}
345
/// Base trait for AI providers (OpenAI, Anthropic, etc)
///
/// Implementors must supply `metadata`, `get_name`, `get_model_config` and
/// `complete_with_model`; every other method has a default implementation
/// built on top of those.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Get the metadata for this provider type
    fn metadata() -> ProviderMetadata
    where
        Self: Sized;

    /// Get the name of this provider instance
    fn get_name(&self) -> &str;

    // Internal implementation of complete, used by complete_fast and complete.
    // Providers should override this to implement their actual completion
    // logic against an explicitly supplied ModelConfig.
    async fn complete_with_model(
        &self,
        model_config: &ModelConfig,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError>;

    // Default implementation: complete using the provider's configured model
    // (whatever get_model_config returns).
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let model_config = self.get_model_config();
        self.complete_with_model(&model_config, system, messages, tools)
            .await
    }

    // Complete using the fast model variant when one is configured, otherwise
    // fall back to the regular model. If the fast attempt errors AND the fast
    // model differs from the regular one, a warning is logged and the call is
    // retried with the regular model; otherwise the original error is returned.
    async fn complete_fast(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let model_config = self.get_model_config();
        let fast_config = model_config.use_fast_model();

        match self
            .complete_with_model(&fast_config, system, messages, tools)
            .await
        {
            Ok(result) => Ok(result),
            Err(e) => {
                // Only fall back when a distinct fast model was actually used;
                // retrying with an identical config would be pointless.
                if fast_config.model_name != model_config.model_name {
                    tracing::warn!(
                        "Fast model {} failed with error: {}. Falling back to regular model {}",
                        fast_config.model_name,
                        e,
                        model_config.model_name
                    );
                    self.complete_with_model(&model_config, system, messages, tools)
                        .await
                } else {
                    Err(e)
                }
            }
        }
    }

    /// Get the model config from the provider
    fn get_model_config(&self) -> ModelConfig;

    /// Retry policy for this provider; defaults to `RetryConfig::default()`.
    fn retry_config(&self) -> RetryConfig {
        RetryConfig::default()
    }

    /// List the models this provider can serve. The default returns
    /// `Ok(None)`, meaning model listing is not supported.
    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        Ok(None)
    }

    /// Fetch models filtered by canonical registry and usability.
    ///
    /// Keeps only models whose canonical-registry entry declares a "text"
    /// input modality. If that filter removes everything, the unfiltered list
    /// is returned instead so callers always get something usable. Returns
    /// `Ok(None)` when the provider cannot list models at all.
    async fn fetch_recommended_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
        let all_models = match self.fetch_supported_models().await? {
            Some(models) => models,
            None => return Ok(None),
        };

        let registry = CanonicalModelRegistry::bundled().map_err(|e| {
            ProviderError::ExecutionError(format!("Failed to load canonical registry: {}", e))
        })?;

        let provider_name = self.get_name();

        let recommended_models: Vec<String> = all_models
            .iter()
            .filter(|model| {
                // Unmapped models (no canonical entry) are treated as
                // not-recommended rather than errors.
                map_to_canonical_model(provider_name, model, registry)
                    .and_then(|canonical_id| registry.get(&canonical_id))
                    .map(|m| m.input_modalities.contains(&"text".to_string()))
                    .unwrap_or(false)
            })
            .cloned()
            .collect();

        if recommended_models.is_empty() {
            Ok(Some(all_models))
        } else {
            Ok(Some(recommended_models))
        }
    }

    /// Map a provider-specific model name to its canonical registry id.
    /// Returns `Ok(None)` when no canonical mapping exists.
    async fn map_to_canonical_model(
        &self,
        provider_model: &str,
    ) -> Result<Option<String>, ProviderError> {
        let registry = CanonicalModelRegistry::bundled().map_err(|e| {
            ProviderError::ExecutionError(format!("Failed to load canonical registry: {}", e))
        })?;

        Ok(map_to_canonical_model(
            self.get_name(),
            provider_model,
            registry,
        ))
    }

    /// Whether `create_embeddings` is usable; defaults to `false`.
    fn supports_embeddings(&self) -> bool {
        false
    }

    /// Whether this provider supports cache control; defaults to `false`.
    async fn supports_cache_control(&self) -> bool {
        false
    }

    /// Create embeddings if supported. Default implementation returns an error.
    async fn create_embeddings(&self, _texts: Vec<String>) -> Result<Vec<Vec<f32>>, ProviderError> {
        Err(ProviderError::ExecutionError(
            "This provider does not support embeddings".to_string(),
        ))
    }

    /// Check if this provider is a LeadWorkerProvider.
    /// This is used for logging model information at startup.
    fn as_lead_worker(&self) -> Option<&dyn LeadWorkerProviderTrait> {
        None
    }

    /// Stream a completion as incremental [`MessageStream`] items. The default
    /// returns `ProviderError::NotImplemented`; callers should consult
    /// `supports_streaming` first.
    async fn stream(
        &self,
        _system: &str,
        _messages: &[Message],
        _tools: &[Tool],
    ) -> Result<MessageStream, ProviderError> {
        Err(ProviderError::NotImplemented(
            "streaming not implemented".to_string(),
        ))
    }

    /// Whether `stream` is implemented for this provider; defaults to `false`.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Get the currently active model name
    /// For regular providers, this returns the configured model
    /// For LeadWorkerProvider, this returns the currently active model (lead or worker)
    fn get_active_model_name(&self) -> String {
        if let Some(lead_worker) = self.as_lead_worker() {
            lead_worker.get_active_model()
        } else {
            self.get_model_config().model_name
        }
    }

    /// Returns the first `MSG_COUNT_FOR_SESSION_NAME_GENERATION` (3) user
    /// messages as concatenated text, used as context for session naming.
    fn get_initial_user_messages(&self, messages: &Conversation) -> Vec<String> {
        messages
            .iter()
            .filter(|m| m.role == rmcp::model::Role::User)
            .take(MSG_COUNT_FOR_SESSION_NAME_GENERATION)
            .map(|m| m.as_concat_text())
            .collect()
    }

    /// Generate a session name/description based on the conversation history.
    /// Creates a prompt asking for a concise description in 4 words or less,
    /// collapses all whitespace in the reply to single spaces, and truncates
    /// the result with `safe_truncate(.., 100)`.
    async fn generate_session_name(
        &self,
        messages: &Conversation,
    ) -> Result<String, ProviderError> {
        let context = self.get_initial_user_messages(messages);
        let prompt = self.create_session_name_prompt(&context);
        let message = Message::user().with_text(&prompt);
        // Uses the fast model path since naming is a cheap auxiliary request.
        let result = self
            .complete_fast(
                "Reply with only a description in four words or less",
                &[message],
                &[],
            )
            .await?;

        // Normalize whitespace (including newlines) to single spaces.
        let description = result
            .0
            .as_concat_text()
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(" ");

        Ok(safe_truncate(&description, 100))
    }

    // Generate a prompt for a session name based on the conversation history.
    // The instruction is placed after the quoted messages so it is the last
    // thing the model reads.
    fn create_session_name_prompt(&self, context: &[String]) -> String {
        // Create a prompt for a concise description
        let mut prompt = "Based on the conversation so far, provide a concise description of this session in 4 words or less. This will be used for finding the session later in a UI with limited space - reply *ONLY* with the description".to_string();

        if !context.is_empty() {
            prompt = format!(
                "Here are the first few user messages:\n{}\n\n{}",
                context.join("\n"),
                prompt
            );
        }
        prompt
    }

    /// Configure OAuth authentication for this provider
    ///
    /// This method is called when a provider has configuration keys marked with oauth_flow = true.
    /// Providers that support OAuth should override this method to implement their specific OAuth flow.
    ///
    /// # Returns
    /// * `Ok(())` if OAuth configuration succeeds and credentials are saved
    /// * `Err(ProviderError)` if OAuth fails or is not supported by this provider
    ///
    /// # Default Implementation
    /// The default implementation returns an error indicating OAuth is not supported.
    async fn configure_oauth(&self) -> Result<(), ProviderError> {
        Err(ProviderError::ExecutionError(
            "OAuth configuration not supported by this provider".to_string(),
        ))
    }
}
584
/// A message stream yields partial text content but complete tool calls, all within the Message object.
/// So a message with text will contain potentially just a word of a longer response, but tool call
/// messages will only be yielded once concatenated. Each item may also carry a
/// `ProviderUsage` alongside (or instead of) a message chunk.
pub type MessageStream = Pin<
    Box<dyn Stream<Item = Result<(Option<Message>, Option<ProviderUsage>), ProviderError>> + Send>,
>;
591
592pub fn stream_from_single_message(message: Message, usage: ProviderUsage) -> MessageStream {
593    let stream = futures::stream::once(async move { Ok((Some(message), Some(usage))) });
594    Box::pin(stream)
595}
596
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use serde_json::json;

    /// `Usage::new` stores all three counts verbatim when each is supplied.
    #[test]
    fn test_usage_creation() {
        let usage = Usage::new(Some(10), Some(20), Some(30));
        assert_eq!(usage.input_tokens, Some(10));
        assert_eq!(usage.output_tokens, Some(20));
        assert_eq!(usage.total_tokens, Some(30));
    }

    /// Usage round-trips through serde_json and keeps the expected JSON keys.
    #[test]
    fn test_usage_serialization() -> Result<()> {
        let usage = Usage::new(Some(10), Some(20), Some(30));
        let serialized = serde_json::to_string(&usage)?;
        let deserialized: Usage = serde_json::from_str(&serialized)?;

        assert_eq!(usage.input_tokens, deserialized.input_tokens);
        assert_eq!(usage.output_tokens, deserialized.output_tokens);
        assert_eq!(usage.total_tokens, deserialized.total_tokens);

        // Test JSON structure
        let json_value: serde_json::Value = serde_json::from_str(&serialized)?;
        assert_eq!(json_value["input_tokens"], json!(10));
        assert_eq!(json_value["output_tokens"], json!(20));
        assert_eq!(json_value["total_tokens"], json!(30));

        Ok(())
    }

    /// The CURRENT_MODEL global reflects the most recent set_current_model call.
    /// NOTE(review): mutates process-global state — could interfere with other
    /// tests touching CURRENT_MODEL if any are added and run in parallel.
    #[test]
    fn test_set_and_get_current_model() {
        // Set the model
        set_current_model("gpt-4o");

        // Get the model and verify
        let model = get_current_model();
        assert_eq!(model, Some("gpt-4o".to_string()));

        // Change the model
        set_current_model("claude-sonnet-4-20250514");

        // Get the updated model and verify
        let model = get_current_model();
        assert_eq!(model, Some("claude-sonnet-4-20250514".to_string()));
    }

    /// ProviderMetadata::new resolves each model name's context limit through
    /// ModelConfig, with unknown names falling back to the 128k default.
    #[test]
    fn test_provider_metadata_context_limits() {
        // Test that ProviderMetadata::new correctly sets context limits
        let test_models = vec!["gpt-4o", "claude-sonnet-4-20250514", "unknown-model"];
        let metadata = ProviderMetadata::new(
            "test",
            "Test Provider",
            "Test Description",
            "gpt-4o",
            test_models,
            "https://example.com",
            vec![],
        );

        let model_info: HashMap<String, usize> = metadata
            .known_models
            .into_iter()
            .map(|m| (m.name, m.context_limit))
            .collect();

        // gpt-4o should have 128k limit
        assert_eq!(*model_info.get("gpt-4o").unwrap(), 128_000);

        // claude-sonnet-4-20250514 should have 200k limit
        assert_eq!(
            *model_info.get("claude-sonnet-4-20250514").unwrap(),
            200_000
        );

        // unknown model should have default limit (128k)
        assert_eq!(*model_info.get("unknown-model").unwrap(), 128_000);
    }

    /// ModelInfo equality is field-wise (derived PartialEq).
    #[test]
    fn test_model_info_creation() {
        // Test direct ModelInfo creation
        let info = ModelInfo {
            name: "test-model".to_string(),
            context_limit: 1000,
            input_token_cost: None,
            output_token_cost: None,
            currency: None,
            supports_cache_control: None,
        };
        assert_eq!(info.context_limit, 1000);

        // Test equality
        let info2 = ModelInfo {
            name: "test-model".to_string(),
            context_limit: 1000,
            input_token_cost: None,
            output_token_cost: None,
            currency: None,
            supports_cache_control: None,
        };
        assert_eq!(info, info2);

        // Test inequality
        let info3 = ModelInfo {
            name: "test-model".to_string(),
            context_limit: 2000,
            input_token_cost: None,
            output_token_cost: None,
            currency: None,
            supports_cache_control: None,
        };
        assert_ne!(info, info3);
    }

    /// ModelInfo::with_cost stores per-token costs and defaults currency to "$".
    #[test]
    fn test_model_info_with_cost() {
        let info = ModelInfo::with_cost("gpt-4o", 128000, 0.0000025, 0.00001);
        assert_eq!(info.name, "gpt-4o");
        assert_eq!(info.context_limit, 128000);
        assert_eq!(info.input_token_cost, Some(0.0000025));
        assert_eq!(info.output_token_cost, Some(0.00001));
        assert_eq!(info.currency, Some("$".to_string()));
    }
}