use std::ops::Deref;
use std::ops::DerefMut;
use std::sync::Arc;

use crate::chat::CallFunction;
use crate::chat::Completion;
use crate::chat::CompletionError;
use crate::chat::Request;
use crate::chat::ResponseContent;
use crate::chat::ResponseToolCalls;
use crate::chat::ToolCall;
use crate::embeddings::EmbeddingsData;
use crate::embeddings::EmbeddingsError;
use anyhow::Result;

pub use alith_client as client;
pub use alith_client::LLMClient;
pub use alith_client::basic_completion::BasicCompletion;
pub use alith_client::embeddings::Embeddings;
pub use alith_client::prelude::*;
pub use alith_interface::requests::completion::{CompletionRequest, CompletionResponse};
pub use alith_models::api_model::ApiLLMModel;

// Expose the raw completion text through the crate's `ResponseContent` trait.
impl ResponseContent for CompletionResponse {
    fn content(&self) -> String {
        self.content.to_string()
    }
}

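/// Thin wrapper around [`LLMClient`] that adapts it to this crate's
/// `Completion`, tool-call, and embeddings interfaces.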
pub struct Client {
    pub(crate) client: LLMClient,
}

impl Deref for Client {
    type Target = LLMClient;

    fn deref(&self) -> &Self::Target {
        &self.client
    }
}

impl DerefMut for Client {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.client
    }
}

impl Clone for Client {
    fn clone(&self) -> Self {
        // Cloning is cheap: only the `Arc` handle to the shared backend is
        // duplicated, so clones talk to the same underlying client.
        Self {
            client: LLMClient::new(Arc::clone(&self.client.backend)),
        }
    }
}

impl Client {
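    /// Builds a client by inferring the provider from the model name prefix:
    /// `gpt*` maps to OpenAI, `claude*` to Anthropic, and `llama*`/`sonar*`
    /// to Perplexity; any other prefix is rejected.
    ///
    /// A minimal usage sketch (it assumes the matching provider credentials,
    /// e.g. an OpenAI API key, are already configured for the builder):
    ///
    /// ```ignore
    /// let client = Client::from_model_name("gpt-4")?;
    /// ```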
    pub fn from_model_name(model: &str) -> Result<Client> {
        if model.starts_with("gpt") {
            let mut builder = LLMClient::openai();
            builder.model = ApiLLMModel::openai_model_from_model_id(model);
            let client = builder.init()?;
            Ok(Client { client })
        } else if model.starts_with("claude") {
            let mut builder = LLMClient::anthropic();
            builder.model = ApiLLMModel::anthropic_model_from_model_id(model);
            let client = builder.init()?;
            Ok(Client { client })
        } else if model.starts_with("llama") || model.starts_with("sonar") {
            let mut builder = LLMClient::perplexity();
            builder.model = ApiLLMModel::perplexity_model_from_model_id(model);
            let client = builder.init()?;
            Ok(Client { client })
        } else {
            Err(anyhow::anyhow!("unknown model {model}"))
        }
    }

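    /// Builds a client for any OpenAI-compatible endpoint by starting from an
    /// OpenAI builder and overriding the API key, host, and model id.
    ///
    /// A minimal usage sketch; the key, host, and model below are
    /// placeholders, not values shipped with this crate:
    ///
    /// ```ignore
    /// let client = Client::openai_compatible_client(
    ///     "<api-key>",
    ///     "api.example.com",
    ///     "my-model",
    /// )?;
    /// ```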
    pub fn openai_compatible_client(api_key: &str, base_url: &str, model: &str) -> Result<Client> {
        let mut builder = LLMClient::openai();
        // Start from the GPT-4 preset, then point it at the caller-supplied
        // model id, credentials, and host.
        builder.model = ApiLLMModel::gpt_4();
        builder.model.model_base.model_id = model.to_string();
        builder.config.api_config.api_key = Some(api_key.to_string().into());
        builder.config.api_config.host = base_url.to_string();
        builder.config.logging_config.logger_name = "generic".to_string();
        let client = builder.init()?;
        Ok(Client { client })
    }
}

impl ResponseToolCalls for CompletionResponse {
    fn toolcalls(&self) -> Vec<ToolCall> {
        // Map the provider-specific tool calls (if any) into the crate's
        // `ToolCall` type; an absent list yields an empty `Vec`.
        self.tool_calls
            .iter()
            .flatten()
            .map(|call| ToolCall {
                id: call.id.clone(),
                r#type: call.r#type.clone(),
                function: CallFunction {
                    name: call.function.name.clone(),
                    arguments: call.function.arguments.clone(),
                },
            })
            .collect()
    }
}

impl Drop for Client {
    fn drop(&mut self) {
        // Shut down the underlying client so backend resources are released.
        self.client.shutdown();
    }
}

impl Completion for Client {
    type Response = CompletionResponse;

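    /// Translates a crate-level `Request` into a backend completion call:
    /// sampling parameters are forwarded first, then the preamble (as the
    /// system message), the conversation history, the user prompt (with
    /// knowledge snippets and document context appended), and the tools.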
    async fn completion(&mut self, request: Request) -> Result<Self::Response, CompletionError> {
        let mut completion = self.client.basic_completion();
        // Forward sampling parameters when the request provides them.
        if let Some(temperature) = request.temperature {
            completion.temperature(temperature);
        }
        if let Some(max_tokens) = request.max_tokens {
            completion.max_tokens(
                max_tokens
                    .try_into()
                    .map_err(|_| CompletionError::Normal("max_tokens out of range".to_string()))?,
            );
        }
        let prompt = completion.prompt();
        // A non-empty preamble becomes the system message.
        if !request.preamble.trim().is_empty() {
            prompt
                .add_system_message()
                .map_err(|err| CompletionError::Normal(err.to_string()))?
                .set_content(&request.preamble);
        }
        // Replay the conversation history, skipping messages with unknown roles.
        for msg in &request.history {
            let result = match msg.role.as_str() {
                "system" => prompt.add_system_message(),
                "user" => prompt.add_user_message(),
                "assistant" => prompt.add_assistant_message(),
                _ => continue,
            };
            result
                .map_err(|err| CompletionError::Normal(err.to_string()))?
                .set_content(&msg.content);
        }
        // Append retrieved knowledge snippets to the user prompt.
        let mut input = request.prompt.clone();
        for knowledge in &request.knowledges {
            input.push('\n');
            input.push_str(knowledge);
        }
        if request.documents.is_empty() {
            prompt
                .add_user_message()
                .map_err(|err| CompletionError::Normal(err.to_string()))?
                .set_content(&input);
        } else {
            // With attached documents, wrap the prompt in document context.
            prompt
                .add_user_message()
                .map_err(|err| CompletionError::Normal(err.to_string()))?
                .set_content(request.prompt_with_context(input));
        }
        completion.base_req.tools.extend(request.tools.clone());
        completion
            .run()
            .await
            .map_err(|err| CompletionError::Normal(err.to_string()))
    }
}

impl Client {
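    /// Embeds a batch of texts with the given embeddings model, pairing each
    /// returned vector with its source text in input order.
    ///
    /// A minimal usage sketch; the model name is a placeholder for whatever
    /// embeddings model the configured backend accepts:
    ///
    /// ```ignore
    /// let data = client
    ///     .embed_texts("text-embedding-3-small", vec!["hello".to_string()])
    ///     .await?;
    /// assert_eq!(data.len(), 1);
    /// ```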
    pub async fn embed_texts(
        &self,
        model: &str,
        input: Vec<String>,
    ) -> Result<Vec<EmbeddingsData>, EmbeddingsError> {
        let mut embeddings = self.client.embeddings();
        embeddings.set_input(input.clone());
        embeddings.set_model(model.to_string());
        embeddings
            .run()
            .await
            .map(|resp| {
                // The backend returns one vector per input, in order, so
                // zipping the response with the original texts re-associates
                // each vector with its source document.
                resp.data
                    .iter()
                    .zip(input)
                    .map(|(data, document)| EmbeddingsData {
                        document,
                        vec: data.embedding.clone(),
                    })
                    .collect()
            })
            .map_err(|err| EmbeddingsError::ResponseError(err.to_string()))
    }
}