kernelx_core/providers/mod.rs

1use crate::{capabilities::*, models::*, Error, Result};
2use async_trait::async_trait;
3use std::marker::PhantomData;
4
5mod local;
6mod remote;
7
8pub use remote::{IntoOpenAIChatMessage, OpenAI, OpenAIModels};
9
10pub struct ProviderBuilder<P> {
11    pub(crate) api_key: Option<String>,
12    pub(crate) api_base: Option<String>,
13    pub(crate) models: Option<Vec<ModelInfo>>,
14    pub(crate) config: Option<ModelConfig>,
15    pub(crate) _provider: PhantomData<P>,
16}
17
18impl<P> ProviderBuilder<P> {
19    pub fn new() -> Self {
20        Self {
21            api_key: None,
22            api_base: None,
23            models: None,
24            config: None,
25            _provider: PhantomData,
26        }
27    }
28
29    pub fn api_key(mut self, key: impl Into<String>) -> Self {
30        self.api_key = Some(key.into());
31        self
32    }
33
34    pub fn api_base(mut self, base: impl Into<String>) -> Self {
35        self.api_base = Some(base.into());
36        self
37    }
38
39    pub fn models(mut self, models: Vec<ModelInfo>) -> Self {
40        self.models = Some(models);
41        self
42    }
43
44    pub fn with_config(mut self, config: ModelConfig) -> Self {
45        self.config = Some(config);
46        self
47    }
48
49    pub fn build(self) -> Result<P>
50    where
51        P: Provider,
52    {
53        let api_key = self.api_key.as_deref().unwrap_or("").trim();
54        if api_key.is_empty() {
55            return Err(Error::Config("API key is required".into()));
56        }
57
58        P::from_builder(self)
59    }
60
61    pub fn take_models(&mut self) -> Option<Vec<ModelInfo>> {
62        self.models.take()
63    }
64
65    pub fn get_api_key(&self) -> Option<&str> {
66        self.api_key.as_deref()
67    }
68
69    pub fn get_api_base(&self) -> Option<&str> {
70        self.api_base.as_deref()
71    }
72
73    pub fn take_config(&mut self) -> Option<ModelConfig> {
74        self.config.take()
75    }
76}
77
78impl<P> Default for ProviderBuilder<P> {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84#[async_trait]
85pub trait Provider: HasCapability + Send + Sync + Clone + 'static {
86    fn models(&self) -> &[ModelInfo];
87
88    fn from_builder(builder: ProviderBuilder<Self>) -> Result<Self>;
89
90    fn get_model<C: ?Sized>(&self, model_id: impl AsRef<str>) -> Result<Model<Self, C>>
91    where
92        Self: Sized,
93    {
94        let model_id = model_id.as_ref();
95        let model_info = self
96            .models()
97            .iter()
98            .find(|m| m.id == model_id)
99            .ok_or_else(|| Error::ModelNotFound(model_id.to_string()))?;
100
101        Model::new(
102            self.clone(),
103            model_id.to_string(),
104            model_info.capabilities.clone(),
105        )
106    }
107}