1use super::api_client::{ApiClient, AuthMethod};
2use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, ProviderUsage, Usage};
3use super::errors::ProviderError;
4use super::retry::ProviderRetry;
5use super::utils::{
6 get_model, handle_response_google_compat, handle_response_openai_compat,
7 handle_status_openai_compat, is_google_model, stream_openai_compat, RequestLog,
8};
9use crate::config::signup_tetrate::TETRATE_DEFAULT_MODEL;
10use crate::conversation::message::Message;
11use anyhow::Result;
12use async_trait::async_trait;
13use serde_json::Value;
14
15use crate::model::ModelConfig;
16use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
17use rmcp::model::Tool;
18
/// Model identifiers advertised for this provider through
/// [`ProviderMetadata`] (see `TetrateProvider::metadata`).
///
/// The order of the entries is preserved as-is when the list is handed to
/// the metadata constructor.
pub const TETRATE_KNOWN_MODELS: &[&str] = &[
    // Claude family
    "claude-opus-4-1",
    "claude-3-7-sonnet-latest",
    "claude-sonnet-4-20250514",
    // Gemini family
    "gemini-2.5-pro",
    "gemini-2.0-flash",
    "gemini-2.0-flash-lite",
    // GPT family
    "gpt-5",
    "gpt-5-mini",
    "gpt-5-nano",
    "gpt-4.1",
];
/// Documentation URL reported in the provider metadata.
pub const TETRATE_DOC_URL: &str = "https://router.tetrate.ai";
33
34#[derive(serde::Serialize)]
35pub struct TetrateProvider {
36 #[serde(skip)]
37 api_client: ApiClient,
38 model: ModelConfig,
39 supports_streaming: bool,
40 #[serde(skip)]
41 name: String,
42}
43
44impl TetrateProvider {
45 pub async fn from_env(model: ModelConfig) -> Result<Self> {
46 let config = crate::config::Config::global();
47 let api_key: String = config.get_secret("TETRATE_API_KEY")?;
48 let host: String = config
50 .get_param("TETRATE_HOST")
51 .unwrap_or_else(|_| "https://api.router.tetrate.ai".to_string());
52
53 let auth = AuthMethod::BearerToken(api_key);
54 let api_client = ApiClient::new(host, auth)?
55 .with_header("HTTP-Referer", "https://astercloud.github.io/aster-rust")?
56 .with_header("X-Title", "aster")?;
57
58 Ok(Self {
59 api_client,
60 model,
61 supports_streaming: true,
62 name: Self::metadata().name,
63 })
64 }
65
66 async fn post(&self, payload: &Value) -> Result<Value, ProviderError> {
67 let response = self
68 .api_client
69 .response_post("v1/chat/completions", payload)
70 .await?;
71
72 if is_google_model(payload) {
74 return handle_response_google_compat(response).await;
75 }
76
77 let response_body = handle_response_openai_compat(response)
79 .await
80 .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))?;
81
82 let _debug = format!(
83 "Tetrate Agent Router Service request with payload: {} and response: {}",
84 serde_json::to_string_pretty(payload).unwrap_or_else(|_| "Invalid JSON".to_string()),
85 serde_json::to_string_pretty(&response_body)
86 .unwrap_or_else(|_| "Invalid JSON".to_string())
87 );
88
89 if let Some(error_obj) = response_body.get("error") {
91 let error_message = error_obj
93 .get("message")
94 .and_then(|m| m.as_str())
95 .unwrap_or("Unknown Tetrate Agent Router Service error");
96
97 let error_code = error_obj.get("code").and_then(|c| c.as_u64()).unwrap_or(0);
98
99 if error_code == 400 && error_message.contains("maximum context length") {
101 return Err(ProviderError::ContextLengthExceeded(
102 error_message.to_string(),
103 ));
104 }
105
106 match error_code {
108 401 | 403 => return Err(ProviderError::Authentication(error_message.to_string())),
109 429 => {
110 return Err(ProviderError::RateLimitExceeded {
111 details: error_message.to_string(),
112 retry_delay: None,
113 })
114 }
115 500 | 503 => return Err(ProviderError::ServerError(error_message.to_string())),
116 _ => return Err(ProviderError::RequestFailed(error_message.to_string())),
117 }
118 }
119
120 Ok(response_body)
122 }
123}
124
125#[async_trait]
126impl Provider for TetrateProvider {
127 fn metadata() -> ProviderMetadata {
128 ProviderMetadata::new(
129 "tetrate",
130 "Tetrate Agent Router Service",
131 "Enterprise router for AI models",
132 TETRATE_DEFAULT_MODEL,
133 TETRATE_KNOWN_MODELS.to_vec(),
134 TETRATE_DOC_URL,
135 vec![
136 ConfigKey::new("TETRATE_API_KEY", true, true, None),
137 ConfigKey::new(
138 "TETRATE_HOST",
139 false,
140 false,
141 Some("https://api.router.tetrate.ai"),
142 ),
143 ],
144 )
145 }
146
147 fn get_name(&self) -> &str {
148 &self.name
149 }
150
151 fn get_model_config(&self) -> ModelConfig {
152 self.model.clone()
153 }
154
155 #[tracing::instrument(
156 skip(self, model_config, system, messages, tools),
157 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
158 )]
159 async fn complete_with_model(
160 &self,
161 model_config: &ModelConfig,
162 system: &str,
163 messages: &[Message],
164 tools: &[Tool],
165 ) -> Result<(Message, ProviderUsage), ProviderError> {
166 let payload = create_request(
167 model_config,
168 system,
169 messages,
170 tools,
171 &super::utils::ImageFormat::OpenAi,
172 false,
173 )?;
174 let mut log = RequestLog::start(model_config, &payload)?;
175
176 let response = self
178 .with_retry(|| async {
179 let payload_clone = payload.clone();
180 self.post(&payload_clone).await
181 })
182 .await?;
183
184 let message = response_to_message(&response)?;
186 let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
187 tracing::debug!("Failed to get usage data");
188 Usage::default()
189 });
190 let model = get_model(&response);
191 log.write(&response, Some(&usage))?;
192 Ok((message, ProviderUsage::new(model, usage)))
193 }
194
195 async fn stream(
196 &self,
197 system: &str,
198 messages: &[Message],
199 tools: &[Tool],
200 ) -> Result<MessageStream, ProviderError> {
201 let payload = create_request(
202 &self.model,
203 system,
204 messages,
205 tools,
206 &super::utils::ImageFormat::OpenAi,
207 true,
208 )?;
209
210 let mut log = RequestLog::start(&self.model, &payload)?;
211
212 let response = self
213 .with_retry(|| async {
214 let resp = self
215 .api_client
216 .response_post("v1/chat/completions", &payload)
217 .await?;
218 handle_status_openai_compat(resp).await
219 })
220 .await
221 .inspect_err(|e| {
222 let _ = log.error(e);
223 })?;
224
225 stream_openai_compat(response, log)
226 }
227
228 async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
230 let response = match self.api_client.response_get("v1/models").await {
232 Ok(response) => response,
233 Err(e) => {
234 tracing::warn!("Failed to fetch models from Tetrate Agent Router Service API: {}, falling back to manual model entry", e);
235 return Ok(None);
236 }
237 };
238
239 let json: serde_json::Value = match response.json().await {
241 Ok(json) => json,
242 Err(e) => {
243 tracing::warn!("Failed to parse Tetrate Agent Router Service API response as JSON: {}, falling back to manual model entry", e);
244 return Ok(None);
245 }
246 };
247
248 if let Some(err_obj) = json.get("error") {
250 let msg = err_obj
251 .get("message")
252 .and_then(|v| v.as_str())
253 .unwrap_or("unknown error");
254 tracing::warn!(
255 "Tetrate Agent Router Service API returned an error: {}",
256 msg
257 );
258 return Ok(None);
259 }
260
261 let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
264 ProviderError::UsageError("Missing data field in JSON response".into())
265 })?;
266
267 let mut models: Vec<String> = data
268 .iter()
269 .filter_map(|model| {
270 let id = model.get("id").and_then(|v| v.as_str())?;
272
273 let supports_computer_use = model
276 .get("supports_computer_use")
277 .and_then(|v| v.as_bool())
278 .unwrap_or(false);
279
280 if supports_computer_use {
281 Some(id.to_string())
282 } else {
283 tracing::debug!(
284 "Model '{}' does not support computer_use (tool support), skipping",
285 id
286 );
287 None
288 }
289 })
290 .collect();
291
292 if models.is_empty() {
294 tracing::warn!("No models with tool support found in Tetrate Agent Router Service API response, falling back to manual model entry");
295 return Ok(None);
296 }
297
298 models.sort();
299 Ok(Some(models))
300 }
301
302 fn supports_streaming(&self) -> bool {
303 self.supports_streaming
304 }
305}