// albert_api/client.rs

1use crate::error::ApiError;
2use crate::sse::SseParser;
3use crate::types::*;
4use std::collections::VecDeque;
5use std::time::Duration;
6
/// Fallback API endpoint used when no provider or base-url override is given.
const DEFAULT_BASE_URL: &str = "https://api.ternlang.com";
/// Primary response header carrying the server-assigned request id.
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Alternate request-id header checked when the primary one is absent.
const ALT_REQUEST_ID_HEADER: &str = "request-id";
/// Delay before the first retry; doubled on each subsequent attempt.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(500);
/// Upper bound on any single retry delay.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(30);
12
/// Supported LLM backends; selects the base URL, request path, and wire
/// format used by [`TernlangClient`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum LlmProvider {
    /// First-party Ternlang API; requests are sent in the native
    /// `MessageRequest` format without translation.
    Ternlang,
    Anthropic,
    OpenAi,
    HuggingFace,
    Google,
    Azure,
    Aws,
    /// Local Ollama server, addressed through its OpenAI-compatible endpoint.
    Ollama,
    Xai,
}
25
26impl LlmProvider {
27    pub fn default_base_url(&self) -> &'static str {
28        match self {
29            Self::Ternlang => "https://api.ternlang.com",
30            Self::Anthropic => "https://api.anthropic.com",
31            Self::OpenAi => "https://api.openai.com",
32            Self::HuggingFace => "https://api-inference.huggingface.co",
33            Self::Google => "https://generativelanguage.googleapis.com",
34            Self::Azure => "https://api.azure.com",
35            Self::Aws => "https://bedrock-runtime.us-east-1.amazonaws.com",
36            Self::Ollama => "http://localhost:11434",
37            Self::Xai => "https://api.x.ai",
38        }
39    }
40
41    pub fn api_path(&self) -> &'static str {
42        match self {
43            Self::Ternlang => "/v1/messages",
44            Self::Anthropic => "/v1/messages",
45            Self::OpenAi => "/v1/chat/completions",
46            Self::HuggingFace => "/models",
47            Self::Google => "/v1beta",
48            Self::Ollama => "/v1/chat/completions",
49            Self::Xai => "/v1/chat/completions",
50            _ => "/v1/messages",
51        }
52    }
53}
54
/// HTTP client for talking to any supported [`LlmProvider`].
#[derive(Clone)]
pub struct TernlangClient {
    /// Backend the requests are translated for and sent to.
    pub provider: LlmProvider,
    /// Endpoint root; trailing slashes are trimmed when URLs are built.
    pub base_url: String,
    /// Credential source applied to outgoing requests.
    pub auth: AuthSource,
    /// Underlying HTTP client.
    pub http: reqwest::Client,
    /// Number of additional attempts made for retryable failures.
    pub max_retries: u32,
    /// Delay before the first retry; doubled per attempt.
    pub initial_backoff: Duration,
    /// Cap applied to every retry delay.
    pub max_backoff: Duration,
}
65
66impl TernlangClient {
67    pub fn from_auth(auth: AuthSource) -> Self {
68        Self {
69            provider: LlmProvider::Ternlang,
70            base_url: DEFAULT_BASE_URL.to_string(),
71            auth,
72            http: reqwest::Client::new(),
73            max_retries: 3,
74            initial_backoff: DEFAULT_INITIAL_BACKOFF,
75            max_backoff: DEFAULT_MAX_BACKOFF,
76        }
77    }
78
79    pub fn from_env() -> Result<Self, ApiError> {
80        Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
81    }
82
83    #[must_use]
84    pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
85        self.auth = auth;
86        self
87    }
88
89    #[must_use]
90    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
91        self.base_url = base_url.into();
92        self
93    }
94
95    #[must_use]
96    pub fn with_provider(mut self, provider: LlmProvider) -> Self {
97        self.provider = provider;
98        if self.base_url == DEFAULT_BASE_URL {
99            self.base_url = provider.default_base_url().to_string();
100        }
101        self
102    }
103
104    async fn send_raw_request(
105        &self,
106        request: &MessageRequest,
107    ) -> Result<reqwest::Response, ApiError> {
108        let path = self.provider.api_path();
109        let mut request_url = format!("{}/{}", self.base_url.trim_end_matches('/'), path.trim_start_matches('/'));
110        
111        let body = match self.provider {
112            LlmProvider::Google => {
113                let model_id = if request.model.starts_with("models/") {
114                    request.model.clone()
115                } else {
116                    format!("models/{}", request.model)
117                };
118                request_url = format!("{}/{}:generateContent", self.base_url.trim_end_matches('/'), model_id);
119                translate_to_gemini(request)
120            }
121            LlmProvider::Anthropic => translate_to_anthropic(request),
122            LlmProvider::OpenAi | LlmProvider::Ollama | LlmProvider::Xai => translate_to_openai(request),
123            _ => serde_json::to_value(request).map_err(ApiError::from)?,
124        };
125
126        let mut request_builder = self
127            .http
128            .post(&request_url)
129            .header("content-type", "application/json");
130
131        if self.provider == LlmProvider::Anthropic {
132            request_builder = request_builder.header("anthropic-version", "2023-06-01");
133        }
134
135        let request_builder = self.auth.apply(self.provider, request_builder);
136
137        request_builder.json(&body).send().await.map_err(ApiError::from)
138    }
139
140    pub async fn send_message(
141        &self,
142        request: &MessageRequest,
143    ) -> Result<MessageResponse, ApiError> {
144        let request = MessageRequest {
145            stream: false,
146            ..request.clone()
147        };
148        let response = self.send_with_retry(&request).await?;
149        let request_id = request_id_from_headers(response.headers());
150        let response_json = response
151            .json::<serde_json::Value>()
152            .await
153            .map_err(ApiError::from)?;
154        
155        let mut final_response = match self.provider {
156            LlmProvider::Google => translate_from_gemini(response_json, &request.model),
157            LlmProvider::Anthropic => translate_from_anthropic(response_json, &request.model),
158            LlmProvider::OpenAi | LlmProvider::Ollama | LlmProvider::Xai => translate_from_openai(response_json, &request.model),
159            _ => serde_json::from_value::<MessageResponse>(response_json).map_err(ApiError::from)?,
160        };
161
162        if final_response.request_id.is_none() {
163            final_response.request_id = request_id;
164        }
165        Ok(final_response)
166    }
167
168    pub async fn stream_message(
169        &mut self,
170        request: &MessageRequest,
171    ) -> Result<MessageStream, ApiError> {
172        let response = self
173            .send_with_retry(&request.clone().with_streaming())
174            .await?;
175        Ok(MessageStream {
176            _request_id: request_id_from_headers(response.headers()),
177            response,
178            parser: SseParser::new(),
179            pending: VecDeque::new(),
180            done: false,
181        })
182    }
183
184    async fn send_with_retry(
185        &self,
186        request: &MessageRequest,
187    ) -> Result<reqwest::Response, ApiError> {
188        let mut attempts = 0;
189        let mut last_error: Option<ApiError>;
190
191        loop {
192            attempts += 1;
193            match self.send_raw_request(request).await {
194                Ok(response) => match expect_success(response).await {
195                    Ok(response) => return Ok(response),
196                    Err(error) if error.is_retryable() && attempts <= self.max_retries => {
197                        last_error = Some(error);
198                    }
199                    Err(error) => return Err(error),
200                },
201                Err(error) if error.is_retryable() && attempts <= self.max_retries => {
202                    last_error = Some(error);
203                }
204                Err(error) => return Err(error),
205            }
206
207            if attempts > self.max_retries {
208                break;
209            }
210
211            tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
212        }
213
214        Err(ApiError::RetriesExhausted {
215            attempts,
216            last_error: Box::new(last_error.unwrap_or(ApiError::Auth("Max retries exceeded without error capture".to_string()))),
217        })
218    }
219
220    fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
221        let multiplier = 2_u32.pow(attempt.saturating_sub(1));
222        Ok(self
223            .initial_backoff
224            .checked_mul(multiplier)
225            .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
226    }
227
228    pub async fn list_remote_models(&self) -> Result<Vec<String>, ApiError> {
229        match self.provider {
230            LlmProvider::Google => {
231                let url = format!("{}/v1beta/models?key={}", self.base_url.trim_end_matches('/'), self.auth.api_key().unwrap_or(""));
232                let res = self.http.get(&url).send().await.map_err(ApiError::from)?;
233                let json: serde_json::Value = res.json().await.map_err(ApiError::from)?;
234                
235                let mut models = vec![];
236                if let Some(list) = json.get("models").and_then(|m| m.as_array()) {
237                    for m in list {
238                        if let Some(name) = m.get("name").and_then(|n| n.as_str()) {
239                            models.push(name.replace("models/", ""));
240                        }
241                    }
242                }
243                Ok(models)
244            }
245            LlmProvider::OpenAi | LlmProvider::Ollama | LlmProvider::Xai => {
246                let url = format!("{}/v1/models", self.base_url.trim_end_matches('/'));
247                let res = self.auth.apply(self.provider, self.http.get(&url)).send().await.map_err(ApiError::from)?;
248                let json: serde_json::Value = res.json().await.map_err(ApiError::from)?;
249                
250                let mut models = vec![];
251                if let Some(list) = json.get("data").and_then(|m| m.as_array()) {
252                    for m in list {
253                        if let Some(id) = m.get("id").and_then(|i| i.as_str()) {
254                            models.push(id.to_string());
255                        }
256                    }
257                }
258                Ok(models)
259            }
260            _ => Ok(vec![])
261        }
262    }
263
264    pub async fn exchange_oauth_code(
265        &self,
266        _config: OAuthConfig,
267        _request: &OAuthTokenExchangeRequest,
268    ) -> Result<RuntimeTokenSet, ApiError> {
269        Ok(RuntimeTokenSet {
270            access_token: "dummy_token".to_string(),
271            refresh_token: None,
272            expires_at: None,
273            scopes: vec![],
274        })
275    }
276}
277
/// Incremental reader over a streaming (SSE) message response.
#[derive(Debug)]
pub struct MessageStream {
    // Request id captured from the response headers, if the server sent one.
    _request_id: Option<String>,
    // Live HTTP response whose body is consumed chunk by chunk.
    response: reqwest::Response,
    // Stateful SSE parser fed with raw body bytes.
    parser: SseParser,
    // Events parsed but not yet handed to the caller.
    pending: VecDeque<StreamEvent>,
    // Set once the body is exhausted.
    done: bool,
}
286
287impl MessageStream {
288    pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
289        loop {
290            if let Some(event) = self.pending.pop_front() {
291                return Ok(Some(event));
292            }
293            if self.done { return Ok(None); }
294            let chunk = self.response.chunk().await?.ok_or_else(|| {
295                self.done = true;
296                ApiError::Auth("stream closed".to_string())
297            })?;
298            self.pending.extend(self.parser.push(&chunk)?);
299        }
300    }
301}
302
303fn translate_to_anthropic(request: &MessageRequest) -> serde_json::Value {
304    use serde_json::json;
305    let messages: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
306        let content: Vec<serde_json::Value> = msg.content.iter().map(|block| {
307            match block {
308                InputContentBlock::Text { text } => json!({ "type": "text", "text": text }),
309                InputContentBlock::ToolUse { id, name, input } => json!({
310                    "type": "tool_use", "id": id, "name": name, "input": input
311                }),
312                InputContentBlock::ToolResult { tool_use_id, content, is_error } => {
313                    let text = content.iter().filter_map(|c| {
314                        if let ToolResultContentBlock::Text { text } = c { Some(text.clone()) } else { None }
315                    }).collect::<Vec<String>>().join("\n");
316                    json!({
317                        "type": "tool_result", "tool_use_id": tool_use_id, "content": text, "is_error": is_error
318                    })
319                }
320            }
321        }).collect();
322        json!({ "role": msg.role, "content": content })
323    }).collect();
324
325    let mut body = json!({
326        "model": request.model,
327        "messages": messages,
328        "max_tokens": request.max_tokens.unwrap_or(4096),
329        "stream": request.stream
330    });
331    if let Some(system) = &request.system { body["system"] = json!(system); }
332    if let Some(tools) = &request.tools {
333        body["tools"] = json!(tools.iter().map(|t| {
334            json!({ "name": t.name, "description": t.description, "input_schema": t.input_schema })
335        }).collect::<Vec<_>>());
336    }
337    body
338}
339
340fn translate_to_openai(request: &MessageRequest) -> serde_json::Value {
341    use serde_json::json;
342    let mut messages = vec![];
343    if let Some(system) = &request.system { messages.push(json!({ "role": "system", "content": system })); }
344
345    for msg in &request.messages {
346        let mut content_text = String::new();
347        let mut tool_calls = vec![];
348
349        for block in &msg.content {
350            match block {
351                InputContentBlock::Text { text } => content_text.push_str(text),
352                InputContentBlock::ToolUse { id, name, input } => {
353                    tool_calls.push(json!({
354                        "id": id, "type": "function", "function": { "name": name, "arguments": input.to_string() }
355                    }));
356                }
357                InputContentBlock::ToolResult { tool_use_id, content, .. } => {
358                    let text = content.iter().filter_map(|c| {
359                        if let ToolResultContentBlock::Text { text } = c { Some(text.clone()) } else { None }
360                    }).collect::<Vec<String>>().join("\n");
361                    messages.push(json!({ "role": "tool", "tool_call_id": tool_use_id, "content": text }));
362                }
363            }
364        }
365
366        if !content_text.is_empty() || !tool_calls.is_empty() {
367            let mut m = json!({ "role": msg.role });
368            if !content_text.is_empty() { m["content"] = json!(content_text); }
369            if !tool_calls.is_empty() { m["tool_calls"] = json!(tool_calls); }
370            messages.push(m);
371        }
372    }
373
374    let mut body = json!({ "model": request.model, "messages": messages, "stream": request.stream });
375    if let Some(tools) = &request.tools {
376        body["tools"] = json!(tools.iter().map(|t| {
377            json!({ "type": "function", "function": { "name": t.name, "description": t.description, "parameters": t.input_schema } })
378        }).collect::<Vec<_>>());
379    }
380    body
381}
382
383fn translate_to_gemini(request: &MessageRequest) -> serde_json::Value {
384    use serde_json::json;
385    let contents: Vec<serde_json::Value> = request.messages.iter().map(|msg| {
386        let role = if msg.role == "assistant" { "model" } else { "user" };
387        let parts: Vec<serde_json::Value> = msg.content.iter().map(|block| {
388            match block {
389                InputContentBlock::Text { text } => json!({ "text": text }),
390                InputContentBlock::ToolUse { name, input, .. } => json!({ "functionCall": { "name": name, "args": input } }),
391                InputContentBlock::ToolResult { tool_use_id, content, .. } => {
392                    let text = content.iter().filter_map(|c| {
393                        if let ToolResultContentBlock::Text { text } = c { Some(text.clone()) } else { None }
394                    }).collect::<Vec<String>>().join("\n");
395                    json!({ "functionResponse": { "name": tool_use_id, "response": { "result": text } } })
396                }
397            }
398        }).collect();
399        json!({ "role": role, "parts": parts })
400    }).collect();
401
402    let mut body = json!({ "contents": contents });
403    if let Some(system) = &request.system { body["systemInstruction"] = json!({ "parts": [{ "text": system }] }); }
404    if let Some(tools) = &request.tools {
405        let declarations: Vec<serde_json::Value> = tools.iter().map(|t| {
406            json!({ "name": t.name, "description": t.description, "parameters": t.input_schema })
407        }).collect();
408        body["tools"] = json!([{ "functionDeclarations": declarations }]);
409    }
410    body
411}
412
413fn translate_from_anthropic(response: serde_json::Value, model: &str) -> MessageResponse {
414    let mut content = vec![];
415    if let Some(blocks) = response.get("content").and_then(|c| c.as_array()) {
416        for block in blocks {
417            match block.get("type").and_then(|t| t.as_str()) {
418                Some("text") => if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
419                    content.push(OutputContentBlock::Text { text: text.to_string() });
420                },
421                Some("tool_use") => if let (Some(id), Some(name), Some(input)) = (
422                    block.get("id").and_then(|i| i.as_str()),
423                    block.get("name").and_then(|n| n.as_str()),
424                    block.get("input")
425                ) {
426                    content.push(OutputContentBlock::ToolUse { id: id.to_string(), name: name.to_string(), input: input.clone() });
427                },
428                _ => {}
429            }
430        }
431    }
432    let mut usage = Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: 0 };
433    if let Some(u) = response.get("usage") {
434        usage.input_tokens = u.get("input_tokens").and_then(|c| c.as_u64()).unwrap_or(0) as u32;
435        usage.output_tokens = u.get("output_tokens").and_then(|c| c.as_u64()).unwrap_or(0) as u32;
436    }
437    MessageResponse {
438        id: response.get("id").and_then(|i| i.as_str()).unwrap_or("anthropic-response").to_string(),
439        kind: "message".to_string(), role: "assistant".to_string(), content, model: model.to_string(),
440        stop_reason: response.get("stop_reason").and_then(|s| s.as_str()).map(|s| s.to_string()),
441        stop_sequence: None, usage, request_id: None,
442    }
443}
444
445fn translate_from_openai(response: serde_json::Value, model: &str) -> MessageResponse {
446    let mut content = vec![];
447    if let Some(choices) = response.get("choices").and_then(|c| c.as_array()) {
448        if let Some(choice) = choices.first() {
449            if let Some(message) = choice.get("message") {
450                if let Some(text) = message.get("content").and_then(|c| c.as_str()) {
451                    content.push(OutputContentBlock::Text { text: text.to_string() });
452                }
453                if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) {
454                    for call in tool_calls {
455                        if let (Some(id), Some(name), Some(args_str)) = (
456                            call.get("id").and_then(|i| i.as_str()),
457                            call.get("function").and_then(|f| f.get("name")).and_then(|n| n.as_str()),
458                            call.get("function").and_then(|f| f.get("arguments")).and_then(|a| a.as_str())
459                        ) {
460                            if let Ok(args) = serde_json::from_str(args_str) {
461                                content.push(OutputContentBlock::ToolUse { id: id.to_string(), name: name.to_string(), input: args });
462                            }
463                        }
464                    }
465                }
466            }
467        }
468    }
469    let mut usage = Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: 0 };
470    if let Some(u) = response.get("usage") {
471        usage.input_tokens = u.get("prompt_tokens").and_then(|c| c.as_u64()).unwrap_or(0) as u32;
472        usage.output_tokens = u.get("completion_tokens").and_then(|c| c.as_u64()).unwrap_or(0) as u32;
473    }
474    MessageResponse {
475        id: response.get("id").and_then(|i| i.as_str()).unwrap_or("openai-response").to_string(),
476        kind: "message".to_string(), role: "assistant".to_string(), content, model: model.to_string(),
477        stop_reason: Some("end_turn".to_string()), stop_sequence: None, usage, request_id: None,
478    }
479}
480
481fn translate_from_gemini(response: serde_json::Value, model: &str) -> MessageResponse {
482    let mut content = vec![];
483    if let Some(candidates) = response.get("candidates").and_then(|c| c.as_array()) {
484        if let Some(candidate) = candidates.first() {
485            if let Some(parts) = candidate.get("content").and_then(|c| c.get("parts")).and_then(|p| p.as_array()) {
486                for part in parts {
487                    if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
488                        content.push(OutputContentBlock::Text { text: text.to_string() });
489                    }
490                    if let Some(call) = part.get("functionCall") {
491                        if let (Some(name), Some(args)) = (call.get("name").and_then(|n| n.as_str()), call.get("args")) {
492                            content.push(OutputContentBlock::ToolUse { id: name.to_string(), name: name.to_string(), input: args.clone() });
493                        }
494                    }
495                }
496            }
497        }
498    }
499    let mut usage = Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: 0 };
500    if let Some(u) = response.get("usageMetadata") {
501        usage.input_tokens = u.get("promptTokenCount").and_then(|c| c.as_u64()).unwrap_or(0) as u32;
502        usage.output_tokens = u.get("candidatesTokenCount").and_then(|c| c.as_u64()).unwrap_or(0) as u32;
503    }
504    MessageResponse {
505        id: "gemini-response".to_string(), kind: "message".to_string(), role: "assistant".to_string(),
506        content, model: model.to_string(), stop_reason: Some("end_turn".to_string()),
507        stop_sequence: None, usage, request_id: None,
508    }
509}
510
511pub fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
512    match std::env::var(key) {
513        Ok(value) if !value.is_empty() => Ok(Some(value)),
514        Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
515        Err(error) => Err(ApiError::from(error)),
516    }
517}
518
519pub fn read_base_url() -> String {
520    std::env::var("TERNLANG_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
521}
522
523fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
524    headers
525        .get(REQUEST_ID_HEADER)
526        .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
527        .and_then(|value| value.to_str().ok())
528        .map(ToOwned::to_owned)
529}
530
/// Pass 2xx responses through unchanged; turn any other status into an error.
///
/// NOTE(review): every non-success status — including 429 and 5xx — is
/// mapped to `ApiError::Auth`. If `ApiError::is_retryable` does not treat
/// `Auth` as retryable, the retry loop in `send_with_retry` will never
/// actually retry HTTP-level failures. Confirm against the `ApiError`
/// definition; a status-aware error variant would be more accurate.
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
    if response.status().is_success() {
        Ok(response)
    } else {
        Err(ApiError::Auth(format!("HTTP {}", response.status())))
    }
}
538
539pub fn resolve_startup_auth_source() -> Result<AuthSource, ApiError> {
540    if let Some(api_key) = read_env_non_empty("TERNLANG_API_KEY")? {
541        return Ok(AuthSource::ApiKey(api_key));
542    }
543    Ok(AuthSource::None)
544}
545
/// OAuth client configuration. Currently carries no fields; the struct keeps
/// its `Deserialize` derive so the shape can grow without breaking callers.
#[derive(serde::Deserialize)]
pub struct OAuthConfig {}