alith_core/llm/
client.rs

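//! LLM client wrapper that ties `alith_client` to this crate's completion
//! and embeddings traits.
//!
//! A minimal usage sketch, assuming this module is reachable as
//! `alith_core::llm` (the exact path is an assumption based on the file
//! location):
//!
//! ```ignore
//! use alith_core::llm::Client;
//!
//! // The provider is picked from the model-name prefix ("gpt*", "claude*", ...).
//! let mut client = Client::from_model_name("gpt-4")?;
//! ```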
use std::ops::Deref;
use std::ops::DerefMut;
use std::sync::Arc;

use crate::chat::CallFunction;
use crate::chat::Completion;
use crate::chat::CompletionError;
use crate::chat::Request;
use crate::chat::ResponseContent;
use crate::chat::ResponseToolCalls;
use crate::chat::ToolCall;
use crate::embeddings::EmbeddingsData;
use crate::embeddings::EmbeddingsError;
use anyhow::Result;

pub use alith_client as client;
pub use alith_client::LLMClient;
pub use alith_client::basic_completion::BasicCompletion;
pub use alith_client::embeddings::Embeddings;
pub use alith_client::prelude::*;
pub use alith_interface::requests::completion::{CompletionRequest, CompletionResponse};
pub use alith_models::api_model::ApiLLMModel;

impl ResponseContent for CompletionResponse {
    fn content(&self) -> String {
        self.content.to_string()
    }
}

/// Thin wrapper around [`LLMClient`] that implements this crate's completion
/// and embeddings abstractions on top of it.
pub struct Client {
    pub(crate) client: LLMClient,
}

impl Deref for Client {
    type Target = LLMClient;

    fn deref(&self) -> &Self::Target {
        &self.client
    }
}

impl DerefMut for Client {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.client
    }
}

/// Cloning shares the underlying backend: the new client is built from the
/// same `Arc`-wrapped backend rather than a fresh connection.
impl Clone for Client {
    fn clone(&self) -> Self {
        Self {
            client: LLMClient::new(Arc::clone(&self.client.backend)),
        }
    }
}

impl Client {
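    /// Infers the provider from the model-name prefix (`gpt*` for OpenAI,
    /// `claude*` for Anthropic, `llama*`/`sonar*` for Perplexity) and builds
    /// a client for it; any other prefix is an error.
    ///
    /// A usage sketch; credentials are expected to come from the provider
    /// builder's defaults (an assumption, see `alith_client` for details):
    ///
    /// ```ignore
    /// let client = Client::from_model_name("claude-3-5-sonnet")?;
    /// ```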
    pub fn from_model_name(model: &str) -> Result<Client> {
        if model.starts_with("gpt") {
            let mut builder = LLMClient::openai();
            builder.model = ApiLLMModel::openai_model_from_model_id(model);
            let client = builder.init()?;
            Ok(Client { client })
        } else if model.starts_with("claude") {
            let mut builder = LLMClient::anthropic();
            builder.model = ApiLLMModel::anthropic_model_from_model_id(model);
            let client = builder.init()?;
            Ok(Client { client })
        } else if model.starts_with("llama") || model.starts_with("sonar") {
            let mut builder = LLMClient::perplexity();
            builder.model = ApiLLMModel::perplexity_model_from_model_id(model);
            let client = builder.init()?;
            Ok(Client { client })
        } else {
            Err(anyhow::anyhow!("unknown model {model}"))
        }
    }

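    /// Builds a client for any OpenAI-compatible endpoint by starting from
    /// the OpenAI builder and overriding the model id, API key, and host.
    ///
    /// A usage sketch; the key, host, and model name are placeholders:
    ///
    /// ```ignore
    /// let client = Client::openai_compatible_client(
    ///     "sk-...",          // API key for the target service
    ///     "api.example.com", // host serving the OpenAI-compatible API
    ///     "my-model",        // model id as the service expects it
    /// )?;
    /// ```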
    pub fn openai_compatible_client(api_key: &str, base_url: &str, model: &str) -> Result<Client> {
        let mut builder = LLMClient::openai();
        builder.model = ApiLLMModel::gpt_4();
        builder.model.model_base.model_id = model.to_string();
        builder.config.api_config.api_key = Some(api_key.to_string().into());
        builder.config.api_config.host = base_url.to_string();
        builder.config.logging_config.logger_name = "generic".to_string();
        let client = builder.init()?;
        Ok(Client { client })
    }
}

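/// Maps the provider's tool calls from the response into the crate's
/// [`ToolCall`] shape, yielding an empty list when the response carries none.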
impl ResponseToolCalls for CompletionResponse {
    fn toolcalls(&self) -> Vec<ToolCall> {
        // `tool_calls` is an `Option<Vec<_>>`; iterating the `Option` and
        // flattening handles the `None` case without a placeholder vector.
        self.tool_calls
            .iter()
            .flatten()
            .map(|call| ToolCall {
                id: call.id.clone(),
                r#type: call.r#type.clone(),
                function: CallFunction {
                    name: call.function.name.clone(),
                    arguments: call.function.arguments.clone(),
                },
            })
            .collect()
    }
}

/// Shuts the underlying client down when the wrapper is dropped.
impl Drop for Client {
    fn drop(&mut self) {
        self.client.shutdown();
    }
}

impl Completion for Client {
    type Response = CompletionResponse;

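    /// Runs one chat completion. Messages are assembled in order: the
    /// preamble as a system message, prior history, then the user prompt
    /// with any knowledge strings and document context appended.
    ///
    /// A sketch, assuming `Request` offers a prompt-only constructor (the
    /// constructor name here is illustrative, not part of this file):
    ///
    /// ```ignore
    /// let response = client.completion(Request::new("Hello!")).await?;
    /// println!("{}", response.content());
    /// ```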
    async fn completion(&mut self, request: Request) -> Result<Self::Response, CompletionError> {
        // Create the completion request
        let mut completion = self.client.basic_completion();
        if let Some(temperature) = request.temperature {
            completion.temperature(temperature);
        }
        if let Some(max_tokens) = request.max_tokens {
            completion.max_tokens(max_tokens.try_into().unwrap());
        }
        // Construct the prompt
        let prompt = completion.prompt();
        // Add the preamble as a system message if provided
        if !request.preamble.trim().is_empty() {
            prompt
                .add_system_message()
                .map_err(|err| CompletionError::Normal(err.to_string()))?
                .set_content(&request.preamble);
        }
        // Add the conversation history, skipping messages with unknown roles
        for msg in &request.history {
            let result = match msg.role.as_str() {
                "system" => prompt.add_system_message(),
                "user" => prompt.add_user_message(),
                "assistant" => prompt.add_assistant_message(),
                _ => continue,
            };
            result
                .map_err(|err| CompletionError::Normal(err.to_string()))?
                .set_content(&msg.content);
        }
        let mut input = request.prompt.clone();
        // Append knowledge sources if provided
        for knowledge in &request.knowledges {
            input.push('\n');
            input.push_str(knowledge);
        }
        // Add the user prompt, with the document context when documents exist
        if request.documents.is_empty() {
            prompt
                .add_user_message()
                .map_err(|err| CompletionError::Normal(err.to_string()))?
                .set_content(&input);
        } else {
            prompt
                .add_user_message()
                .map_err(|err| CompletionError::Normal(err.to_string()))?
                .set_content(request.prompt_with_context(input));
        }
        // Add custom tools carried on the request
        completion.base_req.tools.extend(request.tools.clone());
        // Execute the completion request
        completion
            .run()
            .await
            .map_err(|err| CompletionError::Normal(err.to_string()))
    }
}


impl Client {
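    /// Embeds a batch of texts with the given embeddings model, pairing each
    /// input string with the vector returned for it.
    ///
    /// A usage sketch; the model name is a placeholder for whatever the
    /// backing service accepts:
    ///
    /// ```ignore
    /// let data = client
    ///     .embed_texts("text-embedding-3-small", vec!["hello".to_string()])
    ///     .await?;
    /// assert_eq!(data[0].document, "hello");
    /// ```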
    pub async fn embed_texts(
        &self,
        model: &str,
        input: Vec<String>,
    ) -> Result<Vec<EmbeddingsData>, EmbeddingsError> {
        let mut embeddings = self.client.embeddings();
        embeddings.set_input(input.clone());
        embeddings.set_model(model.to_string());
        embeddings
            .run()
            .await
            .map(|resp| {
                // Pair each returned embedding with the input text it encodes.
                resp.data
                    .iter()
                    .zip(input)
                    .map(|(data, document)| EmbeddingsData {
                        document,
                        vec: data.embedding.clone(),
                    })
                    .collect()
            })
            .map_err(|err| EmbeddingsError::ResponseError(err.to_string()))
    }
}
199}