1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "snake_case")]
11pub enum GlobalCacheStrategy {
12 ToolBased,
13 SystemPrompt,
14 None,
15}
16
17impl Default for GlobalCacheStrategy {
18 fn default() -> Self {
19 GlobalCacheStrategy::None
20 }
21}
22
23#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
25#[serde(rename_all = "lowercase")]
26pub enum ApiLogLevel {
27 Debug,
28 Info,
29 Warn,
30 Error,
31}
32
33impl Default for ApiLogLevel {
34 fn default() -> Self {
35 ApiLogLevel::Info
36 }
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ApiLogEntry {
42 pub timestamp: String,
43 pub level: ApiLogLevel,
44 pub message: String,
45 pub details: Option<serde_json::Value>,
46}
47
48impl ApiLogEntry {
49 pub fn new(level: ApiLogLevel, message: impl Into<String>) -> Self {
50 Self {
51 timestamp: chrono::Utc::now().to_rfc3339(),
52 level,
53 message: message.into(),
54 details: None,
55 }
56 }
57
58 pub fn with_details(mut self, details: serde_json::Value) -> Self {
59 self.details = Some(details);
60 self
61 }
62
63 pub fn debug(message: impl Into<String>) -> Self {
64 Self::new(ApiLogLevel::Debug, message)
65 }
66
67 pub fn info(message: impl Into<String>) -> Self {
68 Self::new(ApiLogLevel::Info, message)
69 }
70
71 pub fn warn(message: impl Into<String>) -> Self {
72 Self::new(ApiLogLevel::Warn, message)
73 }
74
75 pub fn error(message: impl Into<String>) -> Self {
76 Self::new(ApiLogLevel::Error, message)
77 }
78}
79
80#[derive(Debug, Clone, Default, Serialize, Deserialize)]
82#[serde(rename_all = "camelCase")]
83pub struct ApiUsage {
84 pub input_tokens: i64,
85 pub output_tokens: i64,
86 #[serde(rename = "cache_read_input_tokens")]
87 pub cache_read_input_tokens: Option<i64>,
88 #[serde(rename = "cache_creation_input_tokens")]
89 pub cache_creation_input_tokens: Option<i64>,
90 pub server_tool_use: Option<ServerToolUse>,
91 pub service_tier: Option<&'static str>,
92 pub cache_creation: Option<CacheCreation>,
93 pub inference_geo: Option<&'static str>,
94 pub iterations: Option<Vec<serde_json::Value>>,
95 pub speed: Option<&'static str>,
96}
97
98#[derive(Debug, Clone, Default, Serialize, Deserialize)]
99#[serde(rename_all = "snake_case")]
100pub struct ServerToolUse {
101 pub web_search_requests: i64,
102 pub web_fetch_requests: i64,
103}
104
105#[derive(Debug, Clone, Default, Serialize, Deserialize)]
106#[serde(rename_all = "snake_case")]
107pub struct CacheCreation {
108 pub ephemeral_1h_input_tokens: i64,
109 pub ephemeral_5m_input_tokens: i64,
110}
111
112pub const EMPTY_USAGE: ApiUsage = ApiUsage {
116 input_tokens: 0,
117 cache_creation_input_tokens: Some(0),
118 cache_read_input_tokens: Some(0),
119 output_tokens: 0,
120 server_tool_use: Some(ServerToolUse {
121 web_search_requests: 0,
122 web_fetch_requests: 0,
123 }),
124 service_tier: Some("standard"),
125 cache_creation: Some(CacheCreation {
126 ephemeral_1h_input_tokens: 0,
127 ephemeral_5m_input_tokens: 0,
128 }),
129 inference_geo: Some(""),
130 iterations: Some(Vec::new()),
131 speed: Some("standard"),
132};
133
134#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
136#[serde(rename_all = "kebab-case")]
137pub enum KnownGateway {
138 Litellm,
139 Helicone,
140 Portkey,
141 CloudflareAiGateway,
142 Kong,
143 Braintrust,
144 Databricks,
145}
146
147fn get_gateway_fingerprints() -> HashMap<&'static str, Vec<&'static str>> {
149 let mut fingerprints = HashMap::new();
150 fingerprints.insert("litellm", vec!["x-litellm-"]);
151 fingerprints.insert("helicone", vec!["helicone-"]);
152 fingerprints.insert("portkey", vec!["x-portkey-"]);
153 fingerprints.insert("cloudflare-ai-gateway", vec!["cf-aig-"]);
154 fingerprints.insert("kong", vec!["x-kong-"]);
155 fingerprints.insert("braintrust", vec!["x-bt-"]);
156 fingerprints
157}
158
159fn get_gateway_host_suffixes() -> HashMap<&'static str, Vec<&'static str>> {
161 let mut suffixes = HashMap::new();
162 suffixes.insert(
163 "databricks",
164 vec![
165 ".cloud.databricks.com",
166 ".azuredatabricks.net",
167 ".gcp.databricks.com",
168 ],
169 );
170 suffixes
171}
172
173pub fn detect_gateway(
175 headers: Option<&HashMap<String, String>>,
176 base_url: Option<&str>,
177) -> Option<KnownGateway> {
178 if let Some(hdrs) = headers {
180 let fingerprint_map = get_gateway_fingerprints();
181 for (key, prefixes) in fingerprint_map {
182 for prefix in prefixes {
183 for hdr_name in hdrs.keys() {
184 if hdr_name.to_lowercase().starts_with(prefix) {
185 return match key {
186 "litellm" => Some(KnownGateway::Litellm),
187 "helicone" => Some(KnownGateway::Helicone),
188 "portkey" => Some(KnownGateway::Portkey),
189 "cloudflare-ai-gateway" => Some(KnownGateway::CloudflareAiGateway),
190 "kong" => Some(KnownGateway::Kong),
191 "braintrust" => Some(KnownGateway::Braintrust),
192 "databricks" => Some(KnownGateway::Databricks),
193 _ => None,
194 };
195 }
196 }
197 }
198 }
199 }
200
201 if let Some(url) = base_url {
203 if let Ok(parsed) = url::Url::parse(url) {
204 let host = parsed
205 .host_str()
206 .map(|h| h.to_lowercase())
207 .unwrap_or_default();
208 let suffix_map = get_gateway_host_suffixes();
209 for (key, suffixes) in suffix_map {
210 for suffix in suffixes {
211 if host.ends_with(suffix) {
212 return Some(KnownGateway::Databricks);
213 }
214 }
215 }
216 }
217 }
218
219 None
220}
221
222pub fn get_anthropic_env_metadata() -> serde_json::Value {
224 let mut metadata = serde_json::Map::new();
225
226 if let Ok(base_url) = std::env::var("AI_CODE_BASE_URL") {
227 metadata.insert("baseUrl".to_string(), serde_json::Value::String(base_url));
228 }
229 if let Ok(model) = std::env::var("AI_CODE_MODEL") {
230 metadata.insert("envModel".to_string(), serde_json::Value::String(model));
231 }
232 if let Ok(small_fast_model) = std::env::var("AI_CODE_SMALL_FAST_MODEL") {
233 metadata.insert(
234 "envSmallFastModel".to_string(),
235 serde_json::Value::String(small_fast_model),
236 );
237 }
238
239 serde_json::Value::Object(metadata)
240}
241
/// Age of the running build in minutes.
///
/// Build time is not tracked here, so this always returns `None`.
pub fn get_build_age_minutes() -> Option<i64> {
    Option::None
}
247
/// Reports whether `AI_CODE_NON_INTERACTIVE` marks this session as non-interactive.
///
/// True when the variable is `"1"` or `"true"` (any letter case). False when it
/// has any other value, is unset, or is not valid Unicode.
pub fn is_non_interactive_session() -> bool {
    match std::env::var("AI_CODE_NON_INTERACTIVE") {
        Ok(flag) => flag == "1" || flag.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
254
/// Returns the API provider name from `AI_CODE_PROVIDER`.
///
/// Falls back to `"firstParty"` when the variable is unset or not valid Unicode.
pub fn get_api_provider_for_statsig() -> String {
    match std::env::var("AI_CODE_PROVIDER") {
        Ok(provider) => provider,
        Err(_) => String::from("firstParty"),
    }
}
259
260pub fn log_api_query(model: &str, messages_length: usize, temperature: f64, query_source: &str) {
262 log::debug!(
263 "[API Query] model={}, messages={}, temp={}, source={}",
264 model,
265 messages_length,
266 temperature,
267 query_source
268 );
269}
270
271pub fn log_api_error(
273 error_message: &str,
274 model: &str,
275 message_count: usize,
276 duration_ms: u64,
277 attempt: u32,
278 status: Option<u16>,
279 error_type: &str,
280) {
281 log::error!(
282 "[API Error] model={}, status={:?}, error={}, attempt={}, duration_ms={}",
283 model,
284 status,
285 error_message,
286 attempt,
287 duration_ms
288 );
289}
290
291pub fn log_api_success(
293 model: &str,
294 message_count: usize,
295 message_tokens: i64,
296 usage: &ApiUsage,
297 duration_ms: u64,
298 attempt: u32,
299 request_id: Option<&str>,
300 stop_reason: Option<&str>,
301 cost_usd: f64,
302 query_source: &str,
303) {
304 let input_tokens = usage.input_tokens;
305 let output_tokens = usage.output_tokens;
306 let cached_tokens = usage.cache_read_input_tokens.unwrap_or(0);
307 let uncached_tokens = usage.cache_creation_input_tokens.unwrap_or(0);
308
309 log::debug!(
310 "[API Success] model={}, input={}, output={}, cached={}, uncached={}, duration_ms={}, attempt={}, reason={:?}, cost=${:.4}",
311 model,
312 input_tokens,
313 output_tokens,
314 cached_tokens,
315 uncached_tokens,
316 duration_ms,
317 attempt,
318 stop_reason,
319 cost_usd
320 );
321}
322
323pub struct ApiLogger {
325 enabled: bool,
326 min_level: ApiLogLevel,
327}
328
329impl ApiLogger {
330 pub fn new() -> Self {
331 Self {
332 enabled: true,
333 min_level: ApiLogLevel::Info,
334 }
335 }
336
337 pub fn set_enabled(&mut self, enabled: bool) {
338 self.enabled = enabled;
339 }
340
341 pub fn set_min_level(&mut self, level: ApiLogLevel) {
342 self.min_level = level;
343 }
344
345 pub fn log(&self, entry: &ApiLogEntry) {
346 if !self.enabled {
347 return;
348 }
349
350 let level_priority = match entry.level {
351 ApiLogLevel::Debug => 0,
352 ApiLogLevel::Info => 1,
353 ApiLogLevel::Warn => 2,
354 ApiLogLevel::Error => 3,
355 };
356
357 let min_priority = match self.min_level {
358 ApiLogLevel::Debug => 0,
359 ApiLogLevel::Info => 1,
360 ApiLogLevel::Warn => 2,
361 ApiLogLevel::Error => 3,
362 };
363
364 if level_priority >= min_priority {
365 eprintln!("[API] {:?}: {}", entry.level, entry.message);
366 }
367 }
368}
369
370impl Default for ApiLogger {
371 fn default() -> Self {
372 Self::new()
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built `info` entry has the right level and message and no details.
    #[test]
    fn test_api_log_entry_creation() {
        let ApiLogEntry {
            level,
            message,
            details,
            ..
        } = ApiLogEntry::info("test message");
        assert_eq!(level, ApiLogLevel::Info);
        assert_eq!(message, "test message");
        assert!(details.is_none());
    }

    /// `with_details` attaches the supplied JSON payload.
    #[test]
    fn test_api_log_entry_with_details() {
        let payload = serde_json::json!({"key": "value"});
        let entry = ApiLogEntry::info("test").with_details(payload);
        assert!(entry.details.is_some());
    }
}