// alith_core/llm.rs

pub mod client;

use crate::chat::{Completion, CompletionError};
use crate::embeddings::{Embeddings, EmbeddingsData, EmbeddingsError};
use anyhow::Result;
use async_trait::async_trait;
use client::{Client, CompletionResponse};

#[cfg(feature = "inference")]
use fastembed::TextEmbedding;
#[cfg(feature = "inference")]
pub use fastembed::{
    EmbeddingModel as FastEmbeddingsModelName, ExecutionProviderDispatch,
    InitOptions as FastEmbeddingsModelOptions,
};
#[cfg(feature = "inference")]
use std::sync::Arc;

// OpenAI models

pub const GPT_4_5: &str = "gpt-4.5";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_3_5_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";

// Anthropic models

pub const CLAUDE_3_OPUS: &str = "claude-3-opus";
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet";
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku";
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet";
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet";

// Remote Llama models

pub const LLAMA_3_1_SONAR_SMALL_ONLINE: &str = "llama-3.1-sonar-small-128k-online";
pub const LLAMA_3_1_SONAR_LARGE_ONLINE: &str = "llama-3.1-sonar-large-128k-online";
pub const LLAMA_3_1_SONAR_HUGE_ONLINE: &str = "llama-3.1-sonar-huge-128k-online";
pub const LLAMA_3_1_SONAR_SMALL_CHAT: &str = "llama-3.1-sonar-small-128k-chat";
pub const LLAMA_3_1_SONAR_LARGE_CHAT: &str = "llama-3.1-sonar-large-128k-chat";
pub const LLAMA_3_1_8B_INSTRUCT: &str = "llama-3.1-8b-instruct";
pub const LLAMA_3_1_70B_INSTRUCT: &str = "llama-3.1-70b-instruct";

// Remote Sonar models

pub const SONAR_SMALL: &str = "sonar_small";
pub const SONAR_LARGE: &str = "sonar_large";
pub const SONAR_HUGE: &str = "sonar_huge";

pub mod openai_compatible {
    pub const DEEPSEEK_BASE_URL: &str = "https://api.deepseek.com";
    pub const GROQ_BASE_URL: &str = "https://api.groq.com/openai/v1";
    pub const HUNYUAN_BASE_URL: &str = "https://api.hunyuan.cloud.tencent.com/v1";
    pub const MINIMAX_BASE_URL: &str = "https://api.minimax.chat/v1";
    pub const MISTRAL_BASE_URL: &str = "https://api.mistral.ai/v1";
    pub const MOONSHOT_BASE_URL: &str = "https://api.moonshot.cn/v1";
    pub const PERPLEXITY_BASE_URL: &str = "https://api.perplexity.ai";
    pub const QIANWEN_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1";
}
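
// Usage sketch (hedged): these base URLs are meant to be passed to
// `LLM::openai_compatible_model` below. The environment variable and the
// "deepseek-chat" model name are illustrative assumptions, not part of this
// crate:
//
// let api_key = std::env::var("DEEPSEEK_API_KEY")?; // assumed env var
// let llm = LLM::openai_compatible_model(
//     &api_key,
//     openai_compatible::DEEPSEEK_BASE_URL,
//     "deepseek-chat", // assumed model name
// )?;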

/// A struct representing a Large Language Model (LLM).
pub struct LLM {
    /// The name or identifier of the model to use.
    /// Examples: "gpt-4", "gpt-3.5-turbo", etc.
    pub model: String,
    /// The LLM client used to communicate with model backends.
    client: Client,
}

impl LLM {
    /// Create an `LLM` whose backend client is resolved from the model name.
    pub fn from_model_name(model: &str) -> Result<Self> {
        Ok(Self {
            model: model.to_string(),
            client: Client::from_model_name(model)?,
        })
    }

    /// Create an `LLM` backed by an OpenAI-compatible API at `base_url`.
    pub fn openai_compatible_model(api_key: &str, base_url: &str, model: &str) -> Result<Self> {
        Ok(Self {
            model: model.to_string(),
            client: Client::openai_compatible_client(api_key, base_url, model)?,
        })
    }

    /// Create an embeddings model that reuses this LLM's client.
    pub fn embeddings_model(&self, model: &str) -> EmbeddingsModel {
        EmbeddingsModel {
            model: model.to_string(),
            client: self.client.clone(),
        }
    }

    /// Returns a reference to the underlying client.
    #[inline]
    pub fn client(&self) -> &Client {
        &self.client
    }
}
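
// Usage sketch (hedged): resolving a client from a known model name. Assumes
// whatever credentials `Client::from_model_name` expects (e.g. a provider API
// key in the environment) are configured:
//
// let llm = LLM::from_model_name(GPT_4O_MINI)?;
// let client = llm.client(); // borrow the underlying client if needed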

impl Completion for LLM {
    type Response = CompletionResponse;

    async fn completion(
        &mut self,
        request: crate::chat::Request,
    ) -> Result<Self::Response, CompletionError> {
        self.client.completion(request).await
    }
}
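
// Usage sketch (hedged): driving a completion. How `crate::chat::Request` is
// constructed lives outside this file; the `Default` construction below is a
// hypothetical stand-in:
//
// let mut llm = LLM::from_model_name(GPT_4O_MINI)?;
// let request = crate::chat::Request::default(); // hypothetical constructor
// let response = llm.completion(request).await?;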

/// An embeddings model that delegates to a shared LLM `Client`.
#[derive(Clone)]
pub struct EmbeddingsModel {
    pub client: Client,
    pub model: String,
}

#[async_trait]
impl Embeddings for EmbeddingsModel {
    const MAX_DOCUMENTS: usize = 1024;

    async fn embed_texts(
        &self,
        input: Vec<String>,
    ) -> Result<Vec<EmbeddingsData>, EmbeddingsError> {
        self.client.embed_texts(&self.model, input).await
    }
}
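
// Usage sketch (hedged): embedding documents through the shared client. The
// embeddings model name is an illustrative assumption:
//
// let llm = LLM::from_model_name(GPT_4O_MINI)?;
// let embedder = llm.embeddings_model("text-embedding-3-small"); // assumed name
// let data = embedder
//     .embed_texts(vec!["first doc".to_string(), "second doc".to_string()])
//     .await?;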

/// A local embeddings model backed by `fastembed`.
#[cfg(feature = "inference")]
#[derive(Clone)]
pub struct FastEmbeddingsModel {
    model: Arc<TextEmbedding>,
}

#[cfg(feature = "inference")]
impl FastEmbeddingsModel {
    /// Try to create a new `TextEmbedding` instance with the given options.
    ///
    /// Uses the highest level of graph optimization.
    ///
    /// Uses the total number of CPUs available as the number of intra-threads.
    pub fn try_new(opts: FastEmbeddingsModelOptions) -> anyhow::Result<Self> {
        let model = TextEmbedding::try_new(opts)?;
        Ok(Self {
            model: Arc::new(model),
        })
    }

    /// Try to create a new `TextEmbedding` instance with the default options.
    ///
    /// Uses the highest level of graph optimization.
    ///
    /// Uses the total number of CPUs available as the number of intra-threads.
    #[inline]
    pub fn try_default() -> anyhow::Result<Self> {
        Self::try_new(Default::default())
    }
}
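
// Usage sketch (hedged): constructing a local model. `try_default` loads
// fastembed's default embedding model; the explicit-options variant below
// assumes a fastembed version whose `InitOptions` exposes `new`:
//
// let model = FastEmbeddingsModel::try_default()?;
// // or, selecting a model explicitly (API shape assumed):
// // let opts = FastEmbeddingsModelOptions::new(FastEmbeddingsModelName::AllMiniLML6V2);
// // let model = FastEmbeddingsModel::try_new(opts)?;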

#[cfg(feature = "inference")]
#[async_trait]
impl Embeddings for FastEmbeddingsModel {
    const MAX_DOCUMENTS: usize = 1024;

    async fn embed_texts(
        &self,
        input: Vec<String>,
    ) -> Result<Vec<EmbeddingsData>, EmbeddingsError> {
        // Generate embeddings with the default batch size, 256.
        let embeddings = self
            .model
            .embed(input.clone(), None)
            .map_err(|err| EmbeddingsError::ProviderError(err.to_string()))?;
        // Pair each document with its embedding; fastembed yields f32 vectors,
        // which are widened to f64 for `EmbeddingsData`.
        Ok(input
            .iter()
            .zip(embeddings)
            .map(|(doc, embedding)| EmbeddingsData {
                document: doc.to_string(),
                vec: embedding.iter().map(|e| *e as f64).collect(),
            })
            .collect())
    }
}
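
// Usage sketch (hedged): embedding documents locally. `embed_texts` is async
// to satisfy the `Embeddings` trait even though fastembed runs synchronously:
//
// let model = FastEmbeddingsModel::try_default()?;
// let data = model
//     .embed_texts(vec!["hello".to_string(), "world".to_string()])
//     .await?;
// assert_eq!(data.len(), 2);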