Skip to main content

aster/providers/
tetrate.rs

1use super::api_client::{ApiClient, AuthMethod};
2use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
3use super::errors::ProviderError;
4use super::retry::ProviderRetry;
5use super::utils::{
6    get_model, handle_response_google_compat, handle_response_openai_compat,
7    handle_status_openai_compat, is_google_model, stream_openai_compat, RequestLog,
8};
9use crate::config::signup_tetrate::TETRATE_DEFAULT_MODEL;
10use crate::conversation::message::Message;
11use anyhow::Result;
12use async_trait::async_trait;
13use serde_json::Value;
14
15use crate::model::ModelConfig;
16use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
17use rmcp::model::Tool;
18
/// Model identifiers offered as suggestions for the Tetrate Agent Router Service.
///
/// The service can route to many more models than these; this list is what
/// [`TetrateProvider`] advertises through its provider metadata.
pub const TETRATE_KNOWN_MODELS: &[&str] = &[
    // Anthropic
    "claude-opus-4-1",
    "claude-3-7-sonnet-latest",
    "claude-sonnet-4-20250514",
    // Google
    "gemini-2.5-pro",
    "gemini-2.0-flash",
    "gemini-2.0-flash-lite",
    // OpenAI
    "gpt-5",
    "gpt-5-mini",
    "gpt-5-nano",
    "gpt-4.1",
];

/// Documentation URL reported in the provider metadata.
pub const TETRATE_DOC_URL: &str = "https://router.tetrate.ai";
33
34#[derive(serde::Serialize)]
35pub struct TetrateProvider {
36    #[serde(skip)]
37    api_client: ApiClient,
38    model: ModelConfig,
39    supports_streaming: bool,
40    #[serde(skip)]
41    name: String,
42}
43
44impl TetrateProvider {
45    pub async fn from_env(model: ModelConfig) -> Result<Self> {
46        let config = crate::config::Config::global();
47        let api_key: String = config.get_secret("TETRATE_API_KEY")?;
48        // API host for LLM endpoints (/v1/chat/completions, /v1/models)
49        let host: String = config
50            .get_param("TETRATE_HOST")
51            .unwrap_or_else(|_| "https://api.router.tetrate.ai".to_string());
52
53        let auth = AuthMethod::BearerToken(api_key);
54        let api_client = ApiClient::new(host, auth)?
55            .with_header("HTTP-Referer", "https://astercloud.github.io/aster-rust")?
56            .with_header("X-Title", "aster")?;
57
58        Ok(Self {
59            api_client,
60            model,
61            supports_streaming: true,
62            name: Self::metadata().name,
63        })
64    }
65
66    async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
67        let response = self
68            .api_client
69            .response_post("v1/chat/completions", payload)
70            .await?;
71
72        // Handle Google-compatible model responses differently
73        if is_google_model(payload) {
74            return handle_response_google_compat(response).await;
75        }
76
77        // For OpenAI-compatible models, parse the response body to JSON
78        let response_body = handle_response_openai_compat(response)
79            .await
80            .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))?;
81
82        let _debug = format!(
83            "Tetrate Agent Router Service request with payload: {} and response: {}",
84            serde_json::to_string_pretty(payload).unwrap_or_else(|_| "Invalid JSON".to_string()),
85            serde_json::to_string_pretty(&response_body)
86                .unwrap_or_else(|_| "Invalid JSON".to_string())
87        );
88
89        // Tetrate Agent Router Service can return errors in 200 OK responses, so we have to check for errors explicitly
90        if let Some(error_obj) = response_body.get("error") {
91            // If there's an error object, extract the error message and code
92            let error_message = error_obj
93                .get("message")
94                .and_then(|m| m.as_str())
95                .unwrap_or("Unknown Tetrate Agent Router Service error");
96
97            let error_code = error_obj.get("code").and_then(|c| c.as_u64()).unwrap_or(0);
98
99            // Check for context length errors in the error message
100            if error_code == 400 && error_message.contains("maximum context length") {
101                return Err(ProviderError::ContextLengthExceeded(
102                    error_message.to_string(),
103                ));
104            }
105
106            // Return appropriate error based on the error code
107            match error_code {
108                401 | 403 => return Err(ProviderError::Authentication(error_message.to_string())),
109                429 => {
110                    return Err(ProviderError::RateLimitExceeded {
111                        details: error_message.to_string(),
112                        retry_delay: None,
113                    })
114                }
115                500 | 503 => return Err(ProviderError::ServerError(error_message.to_string())),
116                _ => return Err(ProviderError::RequestFailed(error_message.to_string())),
117            }
118        }
119
120        // No error detected, return the response body
121        Ok(response_body)
122    }
123}
124
125#[async_trait]
126impl Provider for TetrateProvider {
127    fn metadata() -> ProviderMetadata {
128        ProviderMetadata::new(
129            "tetrate",
130            "Tetrate Agent Router Service",
131            "Enterprise router for AI models",
132            TETRATE_DEFAULT_MODEL,
133            TETRATE_KNOWN_MODELS.to_vec(),
134            TETRATE_DOC_URL,
135            vec![
136                ConfigKey::new("TETRATE_API_KEY", true, true, None),
137                ConfigKey::new(
138                    "TETRATE_HOST",
139                    false,
140                    false,
141                    Some("https://api.router.tetrate.ai"),
142                ),
143            ],
144        )
145    }
146
147    fn get_name(&self) -> &str {
148        &self.name
149    }
150
151    fn get_model_config(&self) -> ModelConfig {
152        self.model.clone()
153    }
154
155    #[tracing::instrument(
156        skip(self, model_config, system, messages, tools),
157        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
158    )]
159    async fn complete_with_model(
160        &self,
161        model_config: &ModelConfig,
162        system: &str,
163        messages: &[Message],
164        tools: &[Tool],
165    ) -> Result<(Message, ProviderUsage), ProviderError> {
166        let payload = create_request(
167            model_config,
168            system,
169            messages,
170            tools,
171            &super::utils::ImageFormat::OpenAi,
172            false,
173        )?;
174        let mut log = RequestLog::start(model_config, &payload)?;
175
176        // Make request
177        let response = self
178            .with_retry(|| async {
179                let payload_clone = payload.clone();
180                self.post(&payload_clone).await
181            })
182            .await?;
183
184        // Parse response
185        let message = response_to_message(&response)?;
186        let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
187            tracing::debug!("Failed to get usage data");
188            Usage::default()
189        });
190        let model = get_model(&response);
191        log.write(&response, Some(&usage))?;
192        Ok((message, ProviderUsage::new(model, usage)))
193    }
194
195    async fn stream(
196        &self,
197        system: &str,
198        messages: &[Message],
199        tools: &[Tool],
200    ) -> Result<MessageStream, ProviderError> {
201        let payload = create_request(
202            &self.model,
203            system,
204            messages,
205            tools,
206            &super::utils::ImageFormat::OpenAi,
207            true,
208        )?;
209
210        let mut log = RequestLog::start(&self.model, &payload)?;
211
212        let response = self
213            .with_retry(|| async {
214                let resp = self
215                    .api_client
216                    .response_post("v1/chat/completions", &payload)
217                    .await?;
218                handle_status_openai_compat(resp).await
219            })
220            .await
221            .inspect_err(|e| {
222                let _ = log.error(e);
223            })?;
224
225        stream_openai_compat(response, log)
226    }
227
228    /// Fetch supported models from Tetrate Agent Router Service API (only models with tool support)
229    async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
230        // Use the existing api_client which already has authentication configured
231        let response = match self.api_client.response_get("v1/models").await {
232            Ok(response) => response,
233            Err(e) => {
234                tracing::warn!("Failed to fetch models from Tetrate Agent Router Service API: {}, falling back to manual model entry", e);
235                return Ok(None);
236            }
237        };
238
239        // Handle JSON parsing failures gracefully
240        let json: serde_json::Value = match response.json().await {
241            Ok(json) => json,
242            Err(e) => {
243                tracing::warn!("Failed to parse Tetrate Agent Router Service API response as JSON: {}, falling back to manual model entry", e);
244                return Ok(None);
245            }
246        };
247
248        // Check for error in response
249        if let Some(err_obj) = json.get("error") {
250            let msg = err_obj
251                .get("message")
252                .and_then(|v| v.as_str())
253                .unwrap_or("unknown error");
254            tracing::warn!(
255                "Tetrate Agent Router Service API returned an error: {}",
256                msg
257            );
258            return Ok(None);
259        }
260
261        // The response format from /v1/models is expected to be OpenAI-compatible
262        // It should have a "data" field with an array of model objects
263        let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
264            ProviderError::UsageError("Missing data field in JSON response".into())
265        })?;
266
267        let mut models: Vec<String> = data
268            .iter()
269            .filter_map(|model| {
270                // Get the model ID
271                let id = model.get("id").and_then(|v| v.as_str())?;
272
273                // Check if the model supports computer_use (which indicates tool/function support)
274                // The Tetrate API uses "supports_computer_use" instead of "supported_parameters"
275                let supports_computer_use = model
276                    .get("supports_computer_use")
277                    .and_then(|v| v.as_bool())
278                    .unwrap_or(false);
279
280                if supports_computer_use {
281                    Some(id.to_string())
282                } else {
283                    tracing::debug!(
284                        "Model '{}' does not support computer_use (tool support), skipping",
285                        id
286                    );
287                    None
288                }
289            })
290            .collect();
291
292        // If no models with tool support were found, fall back to manual entry
293        if models.is_empty() {
294            tracing::warn!("No models with tool support found in Tetrate Agent Router Service API response, falling back to manual model entry");
295            return Ok(None);
296        }
297
298        models.sort();
299        Ok(Some(models))
300    }
301
302    fn supports_streaming(&self) -> bool {
303        self.supports_streaming
304    }
305}