1pub mod client;
2
3use crate::chat::{Completion, CompletionError};
4use crate::embeddings::{Embeddings, EmbeddingsData, EmbeddingsError};
5use anyhow::Result;
6use async_trait::async_trait;
7use client::{Client, CompletionResponse};
8
9#[cfg(feature = "inference")]
10use fastembed::TextEmbedding;
11#[cfg(feature = "inference")]
12pub use fastembed::{
13 EmbeddingModel as FastEmbeddingsModelName, ExecutionProviderDispatch,
14 InitOptions as FastEmbeddingsModelOptions,
15};
16#[cfg(feature = "inference")]
17use std::sync::Arc;
18
// --- OpenAI chat model identifiers ---
// NOTE(review): these are short aliases (no date/version suffix); assumed to be
// resolved to concrete provider model ids by the `client` module — confirm there.
pub const GPT_4_5: &str = "gpt-4.5";
pub const GPT_4: &str = "gpt-4";
pub const GPT_4_32K: &str = "gpt-4-32k";
pub const GPT_4_TURBO: &str = "gpt-4-turbo";
pub const GPT_3_5_TURBO: &str = "gpt-3.5-turbo";
pub const GPT_4O_MINI: &str = "gpt-4o-mini";

// --- Anthropic Claude model identifiers ---
pub const CLAUDE_3_OPUS: &str = "claude-3-opus";
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet";
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku";
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet";
pub const CLAUDE_3_7_SONNET: &str = "claude-3-7-sonnet";

// --- Perplexity Llama/Sonar model identifiers ---
// "online" variants have web access; "chat" variants are offline chat models.
pub const LLAMA_3_1_SONAR_SMALL_ONLINE: &str = "llama-3.1-sonar-small-128k-online";
pub const LLAMA_3_1_SONAR_LARGE_ONLINE: &str = "llama-3.1-sonar-large-128k-online";
pub const LLAMA_3_1_SONAR_HUGE_ONLINE: &str = "llama-3.1-sonar-huge-128k-online";
pub const LLAMA_3_1_SONAR_SMALL_CHAT: &str = "llama-3.1-sonar-small-128k-chat";
pub const LLAMA_3_1_SONAR_LARGE_CHAT: &str = "llama-3.1-sonar-large-128k-chat";
pub const LLAMA_3_1_8B_INSTRUCT: &str = "llama-3.1-8b-instruct";
pub const LLAMA_3_1_70B_INSTRUCT: &str = "llama-3.1-70b-instruct";

// --- Perplexity Sonar (newer naming) ---
// NOTE(review): underscores here vs hyphens elsewhere — presumably matched by
// the `client` module's routing; verify against Client::from_model_name.
pub const SONAR_SMALL: &str = "sonar_small";
pub const SONAR_LARGE: &str = "sonar_large";
pub const SONAR_HUGE: &str = "sonar_huge";
51
/// Base URLs for hosted providers that expose an OpenAI-compatible REST API.
///
/// Pass one of these as `base_url` to [`LLM::openai_compatible_model`].
pub mod openai_compatible {
    pub const DEEPSEEK_BASE_URL: &str = "https://api.deepseek.com";
    pub const GROQ_BASE_URL: &str = "https://api.groq.com/openai/v1";
    pub const HUNYUAN_BASE_URL: &str = "https://api.hunyuan.cloud.tencent.com/v1";
    pub const MINIMAX_BASE_URL: &str = "https://api.minimax.chat/v1";
    pub const MISTRAL_BASE_URL: &str = "https://api.mistral.ai/v1";
    pub const MOONSHOT_BASE_URL: &str = "https://api.moonshot.cn/v1";
    pub const PERPLEXITY_BASE_URL: &str = "https://api.perplexity.ai";
    pub const QIANWEN_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1";
}
62
/// A chat-completion handle bound to one model name and one provider client.
pub struct LLM {
    /// Model identifier forwarded to the provider (see the constants above).
    pub model: String,
    // Provider client; cloned when deriving an `EmbeddingsModel`
    // (presumably cheap to clone — confirm in `client::Client`).
    client: Client,
}
71
72impl LLM {
73 pub fn from_model_name(model: &str) -> Result<Self> {
74 Ok(Self {
75 model: model.to_string(),
76 client: Client::from_model_name(model)?,
77 })
78 }
79
80 pub fn openai_compatible_model(api_key: &str, base_url: &str, model: &str) -> Result<Self> {
81 Ok(Self {
82 model: model.to_string(),
83 client: Client::openai_compatible_client(api_key, base_url, model)?,
84 })
85 }
86
87 pub fn embeddings_model(&self, model: &str) -> EmbeddingsModel {
88 EmbeddingsModel {
89 model: model.to_string(),
90 client: self.client.clone(),
91 }
92 }
93
94 #[inline]
95 pub fn client(&self) -> &Client {
96 &self.client
97 }
98}
99
100impl Completion for LLM {
101 type Response = CompletionResponse;
102
103 async fn completion(
104 &mut self,
105 request: crate::chat::Request,
106 ) -> Result<Self::Response, CompletionError> {
107 self.client.completion(request).await
108 }
109}
110
/// A remote embeddings endpoint: a provider client plus the embeddings
/// model name to request.
#[derive(Clone)]
pub struct EmbeddingsModel {
    /// Provider client used to issue embedding requests.
    pub client: Client,
    /// Embeddings model identifier sent with each request.
    pub model: String,
}
116
117#[async_trait]
118impl Embeddings for EmbeddingsModel {
119 const MAX_DOCUMENTS: usize = 1024;
120
121 async fn embed_texts(
122 &self,
123 input: Vec<String>,
124 ) -> Result<Vec<EmbeddingsData>, EmbeddingsError> {
125 self.client.embed_texts(&self.model, input).await
126 }
127}
128
/// Local, in-process embeddings backed by `fastembed`'s `TextEmbedding`.
/// Only compiled with the `inference` feature.
#[cfg(feature = "inference")]
#[derive(Clone)]
pub struct FastEmbeddingsModel {
    // Shared so that `Clone` reuses the already-loaded model instead of
    // reloading weights.
    model: Arc<TextEmbedding>,
}
134
#[cfg(feature = "inference")]
impl FastEmbeddingsModel {
    /// Loads a fastembed text-embedding model with the given options.
    ///
    /// # Errors
    /// Propagates any failure from `TextEmbedding::try_new` (e.g. model
    /// initialization problems).
    pub fn try_new(opts: FastEmbeddingsModelOptions) -> anyhow::Result<Self> {
        Ok(Self {
            model: Arc::new(TextEmbedding::try_new(opts)?),
        })
    }

    /// Loads a model using fastembed's default options.
    ///
    /// # Errors
    /// Same failure modes as [`Self::try_new`].
    #[inline]
    pub fn try_default() -> anyhow::Result<Self> {
        Self::try_new(FastEmbeddingsModelOptions::default())
    }
}
159
#[cfg(feature = "inference")]
#[async_trait]
impl Embeddings for FastEmbeddingsModel {
    /// Maximum number of documents accepted in a single batch.
    const MAX_DOCUMENTS: usize = 1024;

    /// Embeds `input` locally with fastembed and pairs each document with
    /// its vector, widening the model's `f32` components to `f64`.
    ///
    /// # Errors
    /// Returns `EmbeddingsError::ProviderError` when the underlying model
    /// fails to embed the batch.
    async fn embed_texts(
        &self,
        input: Vec<String>,
    ) -> Result<Vec<EmbeddingsData>, EmbeddingsError> {
        // Lend `&str` views to the model instead of deep-cloning every
        // String (`embed` is generic over `S: AsRef<str>`); the owned
        // Strings are then moved into the results below.
        let embeddings = self
            .model
            .embed(input.iter().map(String::as_str).collect::<Vec<_>>(), None)
            .map_err(|err| EmbeddingsError::ProviderError(err.to_string()))?;
        // `embed` preserves input order, so zipping realigns doc↔vector.
        Ok(input
            .into_iter()
            .zip(embeddings)
            .map(|(document, embedding)| EmbeddingsData {
                document,
                vec: embedding.into_iter().map(f64::from).collect(),
            })
            .collect())
    }
}