Skip to main content

aster/providers/
utils.rs

1use super::base::{MessageStream, Usage};
2use super::errors::GoogleErrorCode;
3use crate::config::paths::Paths;
4use crate::model::ModelConfig;
5use crate::providers::errors::ProviderError;
6use crate::providers::formats::openai::response_to_streaming_message;
7use anyhow::{anyhow, Result};
8use async_stream::try_stream;
9use base64::Engine;
10use futures::TryStreamExt;
11use regex::Regex;
12use reqwest::{Response, StatusCode};
13use rmcp::model::{AnnotateAble, ImageContent, RawImageContent};
14use serde::{Deserialize, Serialize};
15use serde_json::{json, Value};
16use std::fmt::Display;
17use std::fs::File;
18use std::io;
19use std::io::{BufWriter, Read, Write};
20use std::path::{Path, PathBuf};
21use std::sync::OnceLock;
22use std::time::Duration;
23use tokio::pin;
24use tokio_stream::StreamExt;
25use tokio_util::codec::{FramedRead, LinesCodec};
26use tokio_util::io::StreamReader;
27use uuid::Uuid;
28
29#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
30pub enum ImageFormat {
31    OpenAi,
32    Anthropic,
33}
34
35/// Convert an image content into an image json based on format
36pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value {
37    match image_format {
38        ImageFormat::OpenAi => json!({
39            "type": "image_url",
40            "image_url": {
41                "url": format!("data:{};base64,{}", image.mime_type, image.data)
42            }
43        }),
44        ImageFormat::Anthropic => json!({
45            "type": "image",
46            "source": {
47                "type": "base64",
48                "media_type": image.mime_type,
49                "data": image.data,
50            }
51        }),
52    }
53}
54
/// Remove the `# Extensions` section from a system prompt.
///
/// Everything from the first occurrence of `# Extensions` up to (but not
/// including) the newline that starts the next `\n# ` heading is dropped, and
/// trailing whitespace before the removed section is trimmed. If no later
/// heading exists, the prompt is cut at the section and right-trimmed. A
/// prompt without `# Extensions` is returned unchanged.
pub fn filter_extensions_from_system_prompt(system: &str) -> String {
    let Some(start) = system.find("# Extensions") else {
        return system.to_string();
    };

    let (Some(head), Some(rest)) = (system.get(..start), system.get(start + 1..)) else {
        return system.to_string();
    };

    match rest.find("\n# ") {
        // `rest` begins one byte after `start`, so the newline introducing the
        // next heading sits at `start + 1 + offset` in `system`.
        Some(offset) => match system.get(start + 1 + offset..) {
            Some(tail) => format!("{}{}", head.trim_end(), tail),
            None => system.to_string(),
        },
        None => head.trim_end().to_string(),
    }
}
79
/// Heuristically decide whether an error message describes a context-window
/// overflow, by case-insensitive substring match against known phrasings.
fn check_context_length_exceeded(text: &str) -> bool {
    const CONTEXT_LENGTH_PHRASES: &[&str] = &[
        "too long",
        "context length",
        "context_length_exceeded",
        "reduce the length",
        "token count",
        "exceeds",
        "exceed context limit",
        "input length",
        "max_tokens",
        "decrease input length",
        "context limit",
    ];

    let lowered = text.to_lowercase();
    for phrase in CONTEXT_LENGTH_PHRASES {
        if lowered.contains(phrase) {
            return true;
        }
    }
    false
}
99
100fn format_server_error_message(status_code: StatusCode, payload: Option<&Value>) -> String {
101    match payload {
102        Some(Value::Null) | None => format!(
103            "HTTP {}: No response body received from server",
104            status_code.as_u16()
105        ),
106        Some(p) => format!("HTTP {}: {}", status_code.as_u16(), p),
107    }
108}
109
110pub fn map_http_error_to_provider_error(
111    status: StatusCode,
112    payload: Option<Value>,
113) -> ProviderError {
114    let extract_message = || -> String {
115        payload
116            .as_ref()
117            .and_then(|p| {
118                p.get("error")
119                    .and_then(|e| e.get("message"))
120                    .or_else(|| p.get("message"))
121                    .and_then(|m| m.as_str())
122                    .map(String::from)
123            })
124            .unwrap_or_else(|| payload.as_ref().map(|p| p.to_string()).unwrap_or_default())
125    };
126
127    let error = match status {
128        StatusCode::OK => unreachable!("Should not call this function with OK status"),
129        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ProviderError::Authentication(format!(
130            "Authentication failed. Status: {}. Response: {}",
131            status,
132            extract_message()
133        )),
134        StatusCode::NOT_FOUND => {
135            ProviderError::RequestFailed(format!("Resource not found (404): {}", extract_message()))
136        }
137        StatusCode::PAYLOAD_TOO_LARGE => ProviderError::ContextLengthExceeded(extract_message()),
138        StatusCode::BAD_REQUEST => {
139            let payload_str = extract_message();
140            if check_context_length_exceeded(&payload_str) {
141                ProviderError::ContextLengthExceeded(payload_str)
142            } else {
143                ProviderError::RequestFailed(format!("Bad request (400): {}", payload_str))
144            }
145        }
146        StatusCode::TOO_MANY_REQUESTS => ProviderError::RateLimitExceeded {
147            details: extract_message(),
148            retry_delay: None,
149        },
150        _ if status.is_server_error() => {
151            ProviderError::ServerError(format!("Server error ({}): {}", status, extract_message()))
152        }
153        _ => ProviderError::RequestFailed(format!(
154            "Request failed with status {}: {}",
155            status,
156            extract_message()
157        )),
158    };
159
160    if !status.is_success() {
161        tracing::warn!(
162            "Provider request failed with status: {}. Payload: {:?}. Returning error: {:?}",
163            status,
164            payload,
165            error
166        );
167    }
168
169    error
170}
171
172pub async fn handle_status_openai_compat(response: Response) -> Result<Response, ProviderError> {
173    let status = response.status();
174    if !status.is_success() {
175        let body = response.text().await.unwrap_or_default();
176        let payload = serde_json::from_str::<Value>(&body).ok();
177        return Err(map_http_error_to_provider_error(status, payload));
178    }
179    Ok(response)
180}
181
182pub async fn handle_response_openai_compat(response: Response) -> Result<Value, ProviderError> {
183    let response = handle_status_openai_compat(response).await?;
184
185    response.json::<Value>().await.map_err(|e| {
186        ProviderError::RequestFailed(format!("Response body is not valid JSON: {}", e))
187    })
188}
189
190pub fn stream_openai_compat(
191    response: Response,
192    mut log: RequestLog,
193) -> Result<MessageStream, ProviderError> {
194    let stream = response.bytes_stream().map_err(io::Error::other);
195
196    Ok(Box::pin(try_stream! {
197        let stream_reader = StreamReader::new(stream);
198        let framed = FramedRead::new(stream_reader, LinesCodec::new())
199            .map_err(anyhow::Error::from);
200
201        let message_stream = response_to_streaming_message(framed);
202        pin!(message_stream);
203        while let Some(message) = message_stream.next().await {
204            let (message, usage) = message.map_err(|e|
205                ProviderError::RequestFailed(format!("Stream decode error: {}", e))
206            )?;
207            log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
208            yield (message, usage);
209        }
210    }))
211}
212
213pub fn is_google_model(payload: &Value) -> bool {
214    payload
215        .get("model")
216        .and_then(|m| m.as_str())
217        .unwrap_or("")
218        .to_lowercase()
219        .contains("google")
220}
221
222/// Extracts `StatusCode` from response status or payload error code.
223/// This function first checks the status code of the response. If the status is successful (2xx),
224/// it then checks the payload for any error codes and maps them to appropriate `StatusCode`.
225/// If the status is not successful (e.g., 4xx or 5xx), the original status code is returned.
226fn get_google_final_status(status: StatusCode, payload: Option<&Value>) -> StatusCode {
227    // If the status is successful, check for an error in the payload
228    if status.is_success() {
229        if let Some(payload) = payload {
230            if let Some(error) = payload.get("error") {
231                if let Some(code) = error.get("code").and_then(|c| c.as_u64()) {
232                    if let Some(google_error) = GoogleErrorCode::from_code(code) {
233                        return google_error.to_status_code();
234                    }
235                }
236            }
237        }
238    }
239    status
240}
241
242fn parse_google_retry_delay(payload: &Value) -> Option<Duration> {
243    payload
244        .get("error")
245        .and_then(|error| error.get("details"))
246        .and_then(|details| details.as_array())
247        .and_then(|details_array| {
248            details_array.iter().find_map(|detail| {
249                if detail
250                    .get("@type")
251                    .and_then(|t| t.as_str())
252                    .is_some_and(|s| s.ends_with("RetryInfo"))
253                {
254                    detail
255                        .get("retryDelay")
256                        .and_then(|delay| delay.as_str())
257                        .and_then(|s| s.strip_suffix('s'))
258                        .and_then(|num| num.parse::<u64>().ok())
259                        .map(Duration::from_secs)
260                } else {
261                    None
262                }
263            })
264        })
265}
266
267/// Handle response from Google Gemini API-compatible endpoints.
268///
269/// Processes HTTP responses, handling specific statuses and parsing the payload
270/// for error messages. Logs the response payload for debugging purposes.
271///
272/// ### References
273/// - Error Codes: https://ai.google.dev/gemini-api/docs/troubleshooting?lang=python
274///
275/// ### Arguments
276/// - `response`: The HTTP response to process.
277///
278/// ### Returns
279/// - `Ok(Value)`: Parsed JSON on success.
280/// - `Err(ProviderError)`: Describes the failure reason.
281pub async fn handle_response_google_compat(response: Response) -> Result<Value, ProviderError> {
282    let status = response.status();
283    let payload: Option<Value> = response.json().await.ok();
284    let final_status = get_google_final_status(status, payload.as_ref());
285
286    match final_status {
287        StatusCode::OK =>  payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
288        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
289            Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
290                Status: {}. Response: {:?}", final_status, payload )))
291        }
292        StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND => {
293            let mut error_msg = "Unknown error".to_string();
294            if let Some(payload) = &payload {
295                if let Some(error) = payload.get("error") {
296                    error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string();
297                    let error_status = error.get("status").and_then(|s| s.as_str()).unwrap_or("Unknown status");
298                    if error_status == "INVALID_ARGUMENT" && error_msg.to_lowercase().contains("exceeds") {
299                        return Err(ProviderError::ContextLengthExceeded(error_msg.to_string()));
300                    }
301                }
302            }
303            tracing::debug!(
304                "{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload)
305            );
306            Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", final_status, error_msg)))
307        }
308        StatusCode::TOO_MANY_REQUESTS => {
309            let retry_delay = payload.as_ref().and_then(parse_google_retry_delay);
310            Err(ProviderError::RateLimitExceeded {
311                details: format!("{:?}", payload),
312                retry_delay,
313            })
314        }
315        _ if final_status.is_server_error() => Err(ProviderError::ServerError(
316            format_server_error_message(final_status, payload.as_ref()),
317        )),
318        _ => {
319            tracing::debug!(
320                "{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload)
321            );
322            Err(ProviderError::RequestFailed(format!("Request failed with status: {}", final_status)))
323        }
324    }
325}
326
327pub fn sanitize_function_name(name: &str) -> String {
328    static RE: OnceLock<Regex> = OnceLock::new();
329    let re = RE.get_or_init(|| Regex::new(r"[^a-zA-Z0-9_-]").unwrap());
330    re.replace_all(name, "_").to_string()
331}
332
333pub fn is_valid_function_name(name: &str) -> bool {
334    static RE: OnceLock<Regex> = OnceLock::new();
335    let re = RE.get_or_init(|| Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap());
336    re.is_match(name)
337}
338
339/// Extract the model name from a JSON object. Common with most providers to have this top level attribute.
340pub fn get_model(data: &Value) -> String {
341    if let Some(model) = data.get("model") {
342        if let Some(model_str) = model.as_str() {
343            model_str.to_string()
344        } else {
345            "Unknown".to_string()
346        }
347    } else {
348        "Unknown".to_string()
349    }
350}
351
/// Check if a file is actually an image by examining its magic bytes.
///
/// Recognised leading signatures:
/// - PNG: `89 50 4E 47`
/// - JPEG: `FF D8 FF`
/// - GIF: `47 49 46 38` (`GIF8`)
///
/// Returns `false` when the file cannot be opened or read, or when its
/// header matches none of the signatures above.
fn is_image_file(path: &Path) -> bool {
    let Ok(file) = std::fs::File::open(path) else {
        return false;
    };

    // A single `read` call may legally return fewer bytes than requested even
    // before EOF, so collect up to four bytes with `take` + `read_to_end`,
    // which keeps reading until the limit or end of file.
    let mut header = Vec::with_capacity(4);
    if file.take(4).read_to_end(&mut header).is_err() {
        return false;
    }

    matches!(
        header.as_slice(),
        [0x89, 0x50, 0x4E, 0x47, ..] | [0xFF, 0xD8, 0xFF, ..] | [0x47, 0x49, 0x46, 0x38, ..]
    )
}
371
372/// Detect if a string contains a path to an image file
373pub fn detect_image_path(text: &str) -> Option<&str> {
374    // Basic image file extension check
375    let extensions = [".png", ".jpg", ".jpeg"];
376
377    // Find any word that ends with an image extension
378    for word in text.split_whitespace() {
379        if extensions
380            .iter()
381            .any(|ext| word.to_lowercase().ends_with(ext))
382        {
383            let path = Path::new(word);
384            // Check if it's an absolute path and file exists
385            if path.is_absolute() && path.is_file() {
386                // Verify it's actually an image file
387                if is_image_file(path) {
388                    return Some(word);
389                }
390            }
391        }
392    }
393    None
394}
395
396/// Convert a local image file to base64 encoded ImageContent
397pub fn load_image_file(path: &str) -> Result<ImageContent, ProviderError> {
398    let path = Path::new(path);
399
400    // Verify it's an image before proceeding
401    if !is_image_file(path) {
402        return Err(ProviderError::RequestFailed(
403            "File is not a valid image".to_string(),
404        ));
405    }
406
407    // Read the file
408    let bytes = std::fs::read(path)
409        .map_err(|e| ProviderError::RequestFailed(format!("Failed to read image file: {}", e)))?;
410
411    // Detect mime type from extension
412    let mime_type = match path.extension().and_then(|e| e.to_str()) {
413        Some(ext) => match ext.to_lowercase().as_str() {
414            "png" => "image/png",
415            "jpg" | "jpeg" => "image/jpeg",
416            _ => {
417                return Err(ProviderError::RequestFailed(
418                    "Unsupported image format".to_string(),
419                ))
420            }
421        },
422        None => {
423            return Err(ProviderError::RequestFailed(
424                "Unknown image format".to_string(),
425            ))
426        }
427    };
428
429    // Convert to base64
430    let data = base64::prelude::BASE64_STANDARD.encode(&bytes);
431
432    Ok(RawImageContent {
433        mime_type: mime_type.to_string(),
434        data,
435        meta: None,
436    }
437    .no_annotation())
438}
439
440pub fn unescape_json_values(value: &Value) -> Value {
441    let mut cloned = value.clone();
442    unescape_json_values_in_place(&mut cloned);
443    cloned
444}
445
446fn unescape_json_values_in_place(value: &mut Value) {
447    match value {
448        Value::Object(map) => {
449            for v in map.values_mut() {
450                unescape_json_values_in_place(v);
451            }
452        }
453        Value::Array(arr) => {
454            for v in arr.iter_mut() {
455                unescape_json_values_in_place(v);
456            }
457        }
458        Value::String(s) => {
459            if s.contains('\\') {
460                *s = s
461                    .replace("\\\\n", "\n")
462                    .replace("\\\\t", "\t")
463                    .replace("\\\\r", "\r")
464                    .replace("\\\\\"", "\"")
465                    .replace("\\n", "\n")
466                    .replace("\\t", "\t")
467                    .replace("\\r", "\r")
468                    .replace("\\\"", "\"");
469            }
470        }
471        _ => {}
472    }
473}
474
/// JSONL log of a single provider request and its responses.
///
/// Lines are appended to a temporary file, which is rotated into the logs
/// directory when the log is finished.
pub struct RequestLog {
    // Buffered handle to the temporary log file; taken (set to `None`) once finished.
    writer: Option<BufWriter<File>>,
    // Location of the temporary file that gets renamed into the rotation.
    temp_path: PathBuf,
}

/// Number of rotated `llm_request.<n>.jsonl` files retained in the logs directory.
pub const LOGS_TO_KEEP: usize = 10;
481
482impl RequestLog {
483    pub fn start<Payload>(model_config: &ModelConfig, payload: &Payload) -> Result<Self>
484    where
485        Payload: Serialize,
486    {
487        let logs_dir = Paths::in_state_dir("logs");
488
489        let request_id = Uuid::new_v4();
490        let temp_name = format!("llm_request.{request_id}.jsonl");
491        let temp_path = logs_dir.join(PathBuf::from(temp_name));
492
493        let mut writer = BufWriter::new(
494            File::options()
495                .write(true)
496                .create(true)
497                .truncate(true)
498                .open(&temp_path)?,
499        );
500
501        let data = serde_json::json!({
502            "model_config": model_config,
503            "input": payload,
504        });
505        writeln!(writer, "{}", serde_json::to_string(&data)?)?;
506
507        Ok(Self {
508            writer: Some(writer),
509            temp_path,
510        })
511    }
512
513    fn write_json(&mut self, line: &serde_json::Value) -> Result<()> {
514        let writer = self
515            .writer
516            .as_mut()
517            .ok_or_else(|| anyhow!("logger is finished"))?;
518        writeln!(writer, "{}", serde_json::to_string(line)?)?;
519        Ok(())
520    }
521
522    pub fn error<E>(&mut self, error: E) -> Result<()>
523    where
524        E: Display,
525    {
526        self.write_json(&serde_json::json!({
527            "error": format!("{}", error),
528        }))
529    }
530
531    pub fn write<Payload>(&mut self, data: &Payload, usage: Option<&Usage>) -> Result<()>
532    where
533        Payload: Serialize,
534    {
535        self.write_json(&serde_json::json!({
536            "data": data,
537            "usage": usage,
538        }))
539    }
540
541    fn finish(&mut self) -> Result<()> {
542        if let Some(mut writer) = self.writer.take() {
543            writer.flush()?;
544            let logs_dir = Paths::in_state_dir("logs");
545            let log_path = |i| logs_dir.join(format!("llm_request.{}.jsonl", i));
546
547            for i in (0..LOGS_TO_KEEP - 1).rev() {
548                let _ = std::fs::rename(log_path(i), log_path(i + 1));
549            }
550
551            std::fs::rename(&self.temp_path, log_path(0))?;
552        }
553        Ok(())
554    }
555}
556
557impl Drop for RequestLog {
558    fn drop(&mut self) {
559        if std::thread::panicking() {
560            return;
561        }
562        let _ = self.finish();
563    }
564}
565
566/// Safely parse a JSON string that may contain doubly-encoded or malformed JSON.
567/// This function first attempts to parse the input string as-is. If that fails,
568/// it applies control character escaping and tries again.
569///
570/// This approach preserves valid JSON like `{"key1": "value1",\n"key2": "value"}`
571/// (which contains a literal \n but is perfectly valid JSON) while still fixing
572/// broken JSON like `{"key1": "value1\n","key2": "value"}` (which contains an
573/// unescaped newline character).
574pub fn safely_parse_json(s: &str) -> Result<serde_json::Value, serde_json::Error> {
575    // First, try parsing the string as-is
576    match serde_json::from_str(s) {
577        Ok(value) => Ok(value),
578        Err(_) => {
579            // If that fails, try with control character escaping
580            let escaped = json_escape_control_chars_in_string(s);
581            serde_json::from_str(&escaped)
582        }
583    }
584}
585
/// Escape every raw control character (U+0000 to U+001F) in a string that is
/// meant to be a JSON document.
///
/// Characters with a short JSON escape use it: backspace becomes `\b`, form
/// feed `\f`, newline `\n`, carriage return `\r`, and tab `\t`. All other
/// control characters use the `\u00XX` form with lowercase hex digits.
///
/// Quotes (`"`) and backslashes (`\`) are deliberately left alone. The input
/// is treated as a whole document, so they may be structural or part of
/// existing valid escapes. The aim is to repair a common LLM mistake: raw
/// control characters emitted inside what should be JSON string values.
///
/// Other syntax errors are not repaired. One example is an unescaped quote
/// inside a string value, such as `{"key": "string with " quote"}`.
pub fn json_escape_control_chars_in_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        // Anything above U+001F (including `"` and `\`) is copied through.
        if ch > '\u{001F}' {
            escaped.push(ch);
            continue;
        }
        let short_form = match ch {
            '\u{0008}' => Some("\\b"),
            '\u{000C}' => Some("\\f"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        match short_form {
            Some(sequence) => escaped.push_str(sequence),
            None => escaped.push_str(&format!("\\u{:04x}", u32::from(ch))),
        }
    }
    escaped
}
631
632#[cfg(test)]
633mod tests {
634    use super::*;
635    use serde_json::json;
636
637    #[test]
638    fn test_detect_image_path() {
639        // Create a temporary PNG file with valid PNG magic numbers
640        let temp_dir = tempfile::tempdir().unwrap();
641        let png_path = temp_dir.path().join("test.png");
642        let png_data = [
643            0x89, 0x50, 0x4E, 0x47, // PNG magic number
644            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
645            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
646        ];
647        std::fs::write(&png_path, png_data).unwrap();
648        let png_path_str = png_path.to_str().unwrap();
649
650        // Create a fake PNG (wrong magic numbers)
651        let fake_png_path = temp_dir.path().join("fake.png");
652        std::fs::write(&fake_png_path, b"not a real png").unwrap();
653
654        // Test with valid PNG file using absolute path
655        let text = format!("Here is an image {}", png_path_str);
656        assert_eq!(detect_image_path(&text), Some(png_path_str));
657
658        // Test with non-image file that has .png extension
659        let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap());
660        assert_eq!(detect_image_path(&text), None);
661
662        // Test with non-existent file
663        let text = "Here is a fake.png that doesn't exist";
664        assert_eq!(detect_image_path(text), None);
665
666        // Test with non-image file
667        let text = "Here is a file.txt";
668        assert_eq!(detect_image_path(text), None);
669
670        // Test with relative path (should not match)
671        let text = "Here is a relative/path/image.png";
672        assert_eq!(detect_image_path(text), None);
673    }
674
675    #[test]
676    fn test_load_image_file() {
677        // Create a temporary PNG file with valid PNG magic numbers
678        let temp_dir = tempfile::tempdir().unwrap();
679        let png_path = temp_dir.path().join("test.png");
680        let png_data = [
681            0x89, 0x50, 0x4E, 0x47, // PNG magic number
682            0x0D, 0x0A, 0x1A, 0x0A, // PNG header
683            0x00, 0x00, 0x00, 0x0D, // Rest of fake PNG data
684        ];
685        std::fs::write(&png_path, png_data).unwrap();
686        let png_path_str = png_path.to_str().unwrap();
687
688        // Create a fake PNG (wrong magic numbers)
689        let fake_png_path = temp_dir.path().join("fake.png");
690        std::fs::write(&fake_png_path, b"not a real png").unwrap();
691        let fake_png_path_str = fake_png_path.to_str().unwrap();
692
693        // Test loading valid PNG file
694        let result = load_image_file(png_path_str);
695        assert!(result.is_ok());
696        let image = result.unwrap();
697        assert_eq!(image.mime_type, "image/png");
698
699        // Test loading fake PNG file
700        let result = load_image_file(fake_png_path_str);
701        assert!(result.is_err());
702        assert!(result
703            .unwrap_err()
704            .to_string()
705            .contains("not a valid image"));
706
707        // Test non-existent file
708        let result = load_image_file("nonexistent.png");
709        assert!(result.is_err());
710
711        // Create a GIF file with valid header bytes
712        let gif_path = temp_dir.path().join("test.gif");
713        // Minimal GIF89a header
714        let gif_data = [0x47, 0x49, 0x46, 0x38, 0x39, 0x61];
715        std::fs::write(&gif_path, gif_data).unwrap();
716        let gif_path_str = gif_path.to_str().unwrap();
717
718        // Test loading unsupported GIF format
719        let result = load_image_file(gif_path_str);
720        assert!(result.is_err());
721        assert!(result
722            .unwrap_err()
723            .to_string()
724            .contains("Unsupported image format"));
725    }
726
727    #[test]
728    fn test_sanitize_function_name() {
729        assert_eq!(sanitize_function_name("hello-world"), "hello-world");
730        assert_eq!(sanitize_function_name("hello world"), "hello_world");
731        assert_eq!(sanitize_function_name("hello@world"), "hello_world");
732    }
733
734    #[test]
735    fn test_is_valid_function_name() {
736        assert!(is_valid_function_name("hello-world"));
737        assert!(is_valid_function_name("hello_world"));
738        assert!(!is_valid_function_name("hello world"));
739        assert!(!is_valid_function_name("hello@world"));
740    }
741
742    #[test]
743    fn unescape_json_values_with_object() {
744        let value = json!({"text": "Hello\\nWorld"});
745        let unescaped_value = unescape_json_values(&value);
746        assert_eq!(unescaped_value, json!({"text": "Hello\nWorld"}));
747    }
748
749    #[test]
750    fn unescape_json_values_with_array() {
751        let value = json!(["Hello\\nWorld", "Goodbye\\tWorld"]);
752        let unescaped_value = unescape_json_values(&value);
753        assert_eq!(unescaped_value, json!(["Hello\nWorld", "Goodbye\tWorld"]));
754    }
755
756    #[test]
757    fn unescape_json_values_with_string() {
758        let value = json!("Hello\\nWorld");
759        let unescaped_value = unescape_json_values(&value);
760        assert_eq!(unescaped_value, json!("Hello\nWorld"));
761    }
762
763    #[test]
764    fn unescape_json_values_with_mixed_content() {
765        let value = json!({
766            "text": "Hello\\nWorld\\\\n!",
767            "array": ["Goodbye\\tWorld", "See you\\rlater"],
768            "nested": {
769                "inner_text": "Inner\\\"Quote\\\""
770            }
771        });
772        let unescaped_value = unescape_json_values(&value);
773        assert_eq!(
774            unescaped_value,
775            json!({
776                "text": "Hello\nWorld\n!",
777                "array": ["Goodbye\tWorld", "See you\rlater"],
778                "nested": {
779                    "inner_text": "Inner\"Quote\""
780                }
781            })
782        );
783    }
784
785    #[test]
786    fn unescape_json_values_with_no_escapes() {
787        let value = json!({"text": "Hello World"});
788        let unescaped_value = unescape_json_values(&value);
789        assert_eq!(unescaped_value, json!({"text": "Hello World"}));
790    }
791
792    #[test]
793    fn test_is_google_model() {
794        // Define the test cases as a vector of tuples
795        let test_cases = vec![
796            // (input, expected_result)
797            (json!({ "model": "google_gemini" }), true),
798            (json!({ "model": "microsoft_bing" }), false),
799            (json!({ "model": "" }), false),
800            (json!({}), false),
801            (json!({ "model": "Google_XYZ" }), true),
802            (json!({ "model": "google_abc" }), true),
803        ];
804
805        // Iterate through each test case and assert the result
806        for (payload, expected_result) in test_cases {
807            assert_eq!(is_google_model(&payload), expected_result);
808        }
809    }
810
811    #[test]
812    fn test_get_google_final_status_success() {
813        let status = StatusCode::OK;
814        let payload = json!({});
815        let result = get_google_final_status(status, Some(&payload));
816        assert_eq!(result, StatusCode::OK);
817    }
818
819    #[test]
820    fn test_get_google_final_status_with_error_code() {
821        // Test error code mappings for different payload error codes
822        let test_cases = vec![
823            // (error code, status, expected status code)
824            (200, None, StatusCode::OK),
825            (429, Some(StatusCode::OK), StatusCode::TOO_MANY_REQUESTS),
826            (400, Some(StatusCode::OK), StatusCode::BAD_REQUEST),
827            (401, Some(StatusCode::OK), StatusCode::UNAUTHORIZED),
828            (403, Some(StatusCode::OK), StatusCode::FORBIDDEN),
829            (404, Some(StatusCode::OK), StatusCode::NOT_FOUND),
830            (500, Some(StatusCode::OK), StatusCode::INTERNAL_SERVER_ERROR),
831            (503, Some(StatusCode::OK), StatusCode::SERVICE_UNAVAILABLE),
832            (999, Some(StatusCode::OK), StatusCode::INTERNAL_SERVER_ERROR),
833            (500, Some(StatusCode::BAD_REQUEST), StatusCode::BAD_REQUEST),
834            (
835                404,
836                Some(StatusCode::INTERNAL_SERVER_ERROR),
837                StatusCode::INTERNAL_SERVER_ERROR,
838            ),
839        ];
840
841        for (error_code, status, expected_status) in test_cases {
842            let payload = if let Some(_status) = status {
843                json!({
844                    "error": {
845                        "code": error_code,
846                        "message": "Error message"
847                    }
848                })
849            } else {
850                json!({})
851            };
852
853            let result = get_google_final_status(status.unwrap_or(StatusCode::OK), Some(&payload));
854            assert_eq!(result, expected_status);
855        }
856    }
857
858    #[test]
859    fn test_safely_parse_json() {
860        // Test valid JSON that should parse without escaping (contains proper escape sequence)
861        let valid_json = r#"{"key1": "value1","key2": "value2"}"#;
862        let result = safely_parse_json(valid_json).unwrap();
863        assert_eq!(result["key1"], "value1");
864        assert_eq!(result["key2"], "value2");
865
866        // Test JSON with actual unescaped newlines that needs escaping
867        let invalid_json = "{\"key1\": \"value1\n\",\"key2\": \"value2\"}";
868        let result = safely_parse_json(invalid_json).unwrap();
869        assert_eq!(result["key1"], "value1\n");
870        assert_eq!(result["key2"], "value2");
871
872        // Test already valid JSON - should parse on first try
873        let good_json = r#"{"test": "value"}"#;
874        let result = safely_parse_json(good_json).unwrap();
875        assert_eq!(result["test"], "value");
876
877        // Test completely invalid JSON that can't be fixed
878        let broken_json = r#"{"key": "unclosed_string"#;
879        assert!(safely_parse_json(broken_json).is_err());
880
881        // Test empty object
882        let empty_json = "{}";
883        let result = safely_parse_json(empty_json).unwrap();
884        assert!(result.as_object().unwrap().is_empty());
885
886        // Test JSON with escaped newlines (valid JSON) - should parse on first try
887        let escaped_json = r#"{"key": "value with\nnewline"}"#;
888        let result = safely_parse_json(escaped_json).unwrap();
889        assert_eq!(result["key"], "value with\nnewline");
890    }
891
892    #[test]
893    fn test_json_escape_control_chars_in_string() {
894        // Test basic control character escaping
895        assert_eq!(
896            json_escape_control_chars_in_string("Hello\nWorld"),
897            "Hello\\nWorld"
898        );
899        assert_eq!(
900            json_escape_control_chars_in_string("Hello\tWorld"),
901            "Hello\\tWorld"
902        );
903        assert_eq!(
904            json_escape_control_chars_in_string("Hello\rWorld"),
905            "Hello\\rWorld"
906        );
907
908        // Test multiple control characters
909        assert_eq!(
910            json_escape_control_chars_in_string("Hello\n\tWorld\r"),
911            "Hello\\n\\tWorld\\r"
912        );
913
914        // Test that quotes and backslashes are preserved (not escaped)
915        assert_eq!(
916            json_escape_control_chars_in_string("Hello \"World\""),
917            "Hello \"World\""
918        );
919        assert_eq!(
920            json_escape_control_chars_in_string("Hello\\World"),
921            "Hello\\World"
922        );
923
924        // Test JSON-like string with control characters
925        assert_eq!(
926            json_escape_control_chars_in_string("{\"message\": \"Hello\nWorld\"}"),
927            "{\"message\": \"Hello\\nWorld\"}"
928        );
929
930        // Test no changes for normal strings
931        assert_eq!(
932            json_escape_control_chars_in_string("Hello World"),
933            "Hello World"
934        );
935
936        // Test other control characters get unicode escapes
937        assert_eq!(
938            json_escape_control_chars_in_string("Hello\u{0001}World"),
939            "Hello\\u0001World"
940        );
941    }
942
943    #[test]
944    fn test_parse_google_retry_delay() {
945        let payload = json!({
946            "error": {
947                "details": [
948                    {
949                        "@type": "type.googleapis.com/google.rpc.RetryInfo",
950                        "retryDelay": "42s"
951                    }
952                ]
953            }
954        });
955        assert_eq!(
956            parse_google_retry_delay(&payload),
957            Some(Duration::from_secs(42))
958        );
959    }
960}