1use super::base::{MessageStream, Usage};
2use super::errors::GoogleErrorCode;
3use crate::config::paths::Paths;
4use crate::model::ModelConfig;
5use crate::providers::errors::ProviderError;
6use crate::providers::formats::openai::response_to_streaming_message;
7use anyhow::{anyhow, Result};
8use async_stream::try_stream;
9use base64::Engine;
10use futures::TryStreamExt;
11use regex::Regex;
12use reqwest::{Response, StatusCode};
13use rmcp::model::{AnnotateAble, ImageContent, RawImageContent};
14use serde::{Deserialize, Serialize};
15use serde_json::{json, Value};
16use std::fmt::Display;
17use std::fs::File;
18use std::io;
19use std::io::{BufWriter, Read, Write};
20use std::path::{Path, PathBuf};
21use std::sync::OnceLock;
22use std::time::Duration;
23use tokio::pin;
24use tokio_stream::StreamExt;
25use tokio_util::codec::{FramedRead, LinesCodec};
26use tokio_util::io::StreamReader;
27use uuid::Uuid;
28
/// Wire format used when embedding an image in a provider request payload.
///
/// Selects which JSON shape [`convert_image`] produces.
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
pub enum ImageFormat {
    /// An `image_url` object whose `url` is a base64 `data:` URL.
    OpenAi,
    /// An `image` object with a separate base64 `source` block.
    Anthropic,
}
34
35pub fn convert_image(image: &ImageContent, image_format: &ImageFormat) -> Value {
37 match image_format {
38 ImageFormat::OpenAi => json!({
39 "type": "image_url",
40 "image_url": {
41 "url": format!("data:{};base64,{}", image.mime_type, image.data)
42 }
43 }),
44 ImageFormat::Anthropic => json!({
45 "type": "image",
46 "source": {
47 "type": "base64",
48 "media_type": image.mime_type,
49 "data": image.data,
50 }
51 }),
52 }
53}
54
/// Returns `system` with its `# Extensions` section removed.
///
/// The section starts at the first occurrence of the substring `"# Extensions"`
/// (a plain substring search, not anchored to a line start) and runs up to, but
/// not including, the next `"\n# "` heading. When no later heading exists the
/// section runs to the end of the prompt. Whitespace immediately preceding the
/// section is trimmed. If the marker is absent the prompt is returned unchanged.
pub fn filter_extensions_from_system_prompt(system: &str) -> String {
    const SECTION_MARKER: &str = "# Extensions";
    const NEXT_HEADING: &str = "\n# ";

    let Some(start) = system.find(SECTION_MARKER) else {
        return system.to_string();
    };

    // Skip the leading '#' so the search for the next heading cannot match the
    // section's own marker.
    let (Some(head), Some(rest)) = (system.get(..start), system.get(start + 1..)) else {
        return system.to_string();
    };

    match rest.find(NEXT_HEADING) {
        // `offset` is relative to `rest`, which begins one byte after `start`.
        Some(offset) => match system.get(start + offset + 1..) {
            Some(tail) => format!("{}{}", head.trim_end(), tail),
            None => system.to_string(),
        },
        None => head.trim_end().to_string(),
    }
}
79
/// Heuristically decides whether an error message describes a context-length
/// overflow.
///
/// Returns `true` when `text`, lowercased, contains any of a fixed list of
/// phrases; the comparison is therefore case-insensitive.
fn check_context_length_exceeded(text: &str) -> bool {
    // All markers are lowercase because the haystack is lowercased below.
    const MARKERS: &[&str] = &[
        "too long",
        "context length",
        "context_length_exceeded",
        "reduce the length",
        "token count",
        "exceeds",
        "exceed context limit",
        "input length",
        "max_tokens",
        "decrease input length",
        "context limit",
    ];

    let haystack = text.to_lowercase();
    for marker in MARKERS {
        if haystack.contains(marker) {
            return true;
        }
    }
    false
}
99
100fn format_server_error_message(status_code: StatusCode, payload: Option<&Value>) -> String {
101 match payload {
102 Some(Value::Null) | None => format!(
103 "HTTP {}: No response body received from server",
104 status_code.as_u16()
105 ),
106 Some(p) => format!("HTTP {}: {}", status_code.as_u16(), p),
107 }
108}
109
110pub fn map_http_error_to_provider_error(
111 status: StatusCode,
112 payload: Option<Value>,
113) -> ProviderError {
114 let extract_message = || -> String {
115 payload
116 .as_ref()
117 .and_then(|p| {
118 p.get("error")
119 .and_then(|e| e.get("message"))
120 .or_else(|| p.get("message"))
121 .and_then(|m| m.as_str())
122 .map(String::from)
123 })
124 .unwrap_or_else(|| payload.as_ref().map(|p| p.to_string()).unwrap_or_default())
125 };
126
127 let error = match status {
128 StatusCode::OK => unreachable!("Should not call this function with OK status"),
129 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ProviderError::Authentication(format!(
130 "Authentication failed. Status: {}. Response: {}",
131 status,
132 extract_message()
133 )),
134 StatusCode::NOT_FOUND => {
135 ProviderError::RequestFailed(format!("Resource not found (404): {}", extract_message()))
136 }
137 StatusCode::PAYLOAD_TOO_LARGE => ProviderError::ContextLengthExceeded(extract_message()),
138 StatusCode::BAD_REQUEST => {
139 let payload_str = extract_message();
140 if check_context_length_exceeded(&payload_str) {
141 ProviderError::ContextLengthExceeded(payload_str)
142 } else {
143 ProviderError::RequestFailed(format!("Bad request (400): {}", payload_str))
144 }
145 }
146 StatusCode::TOO_MANY_REQUESTS => ProviderError::RateLimitExceeded {
147 details: extract_message(),
148 retry_delay: None,
149 },
150 _ if status.is_server_error() => {
151 ProviderError::ServerError(format!("Server error ({}): {}", status, extract_message()))
152 }
153 _ => ProviderError::RequestFailed(format!(
154 "Request failed with status {}: {}",
155 status,
156 extract_message()
157 )),
158 };
159
160 if !status.is_success() {
161 tracing::warn!(
162 "Provider request failed with status: {}. Payload: {:?}. Returning error: {:?}",
163 status,
164 payload,
165 error
166 );
167 }
168
169 error
170}
171
172pub async fn handle_status_openai_compat(response: Response) -> Result<Response, ProviderError> {
173 let status = response.status();
174 if !status.is_success() {
175 let body = response.text().await.unwrap_or_default();
176 let payload = serde_json::from_str::<Value>(&body).ok();
177 return Err(map_http_error_to_provider_error(status, payload));
178 }
179 Ok(response)
180}
181
182pub async fn handle_response_openai_compat(response: Response) -> Result<Value, ProviderError> {
183 let response = handle_status_openai_compat(response).await?;
184
185 response.json::<Value>().await.map_err(|e| {
186 ProviderError::RequestFailed(format!("Response body is not valid JSON: {}", e))
187 })
188}
189
190pub fn stream_openai_compat(
191 response: Response,
192 mut log: RequestLog,
193) -> Result<MessageStream, ProviderError> {
194 let stream = response.bytes_stream().map_err(io::Error::other);
195
196 Ok(Box::pin(try_stream! {
197 let stream_reader = StreamReader::new(stream);
198 let framed = FramedRead::new(stream_reader, LinesCodec::new())
199 .map_err(anyhow::Error::from);
200
201 let message_stream = response_to_streaming_message(framed);
202 pin!(message_stream);
203 while let Some(message) = message_stream.next().await {
204 let (message, usage) = message.map_err(|e|
205 ProviderError::RequestFailed(format!("Stream decode error: {}", e))
206 )?;
207 log.write(&message, usage.as_ref().map(|f| f.usage).as_ref())?;
208 yield (message, usage);
209 }
210 }))
211}
212
213pub fn is_google_model(payload: &Value) -> bool {
214 payload
215 .get("model")
216 .and_then(|m| m.as_str())
217 .unwrap_or("")
218 .to_lowercase()
219 .contains("google")
220}
221
222fn get_google_final_status(status: StatusCode, payload: Option<&Value>) -> StatusCode {
227 if status.is_success() {
229 if let Some(payload) = payload {
230 if let Some(error) = payload.get("error") {
231 if let Some(code) = error.get("code").and_then(|c| c.as_u64()) {
232 if let Some(google_error) = GoogleErrorCode::from_code(code) {
233 return google_error.to_status_code();
234 }
235 }
236 }
237 }
238 }
239 status
240}
241
242fn parse_google_retry_delay(payload: &Value) -> Option<Duration> {
243 payload
244 .get("error")
245 .and_then(|error| error.get("details"))
246 .and_then(|details| details.as_array())
247 .and_then(|details_array| {
248 details_array.iter().find_map(|detail| {
249 if detail
250 .get("@type")
251 .and_then(|t| t.as_str())
252 .is_some_and(|s| s.ends_with("RetryInfo"))
253 {
254 detail
255 .get("retryDelay")
256 .and_then(|delay| delay.as_str())
257 .and_then(|s| s.strip_suffix('s'))
258 .and_then(|num| num.parse::<u64>().ok())
259 .map(Duration::from_secs)
260 } else {
261 None
262 }
263 })
264 })
265}
266
267pub async fn handle_response_google_compat(response: Response) -> Result<Value, ProviderError> {
282 let status = response.status();
283 let payload: Option<Value> = response.json().await.ok();
284 let final_status = get_google_final_status(status, payload.as_ref());
285
286 match final_status {
287 StatusCode::OK => payload.ok_or_else( || ProviderError::RequestFailed("Response body is not valid JSON".to_string()) ),
288 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
289 Err(ProviderError::Authentication(format!("Authentication failed. Please ensure your API keys are valid and have the required permissions. \
290 Status: {}. Response: {:?}", final_status, payload )))
291 }
292 StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND => {
293 let mut error_msg = "Unknown error".to_string();
294 if let Some(payload) = &payload {
295 if let Some(error) = payload.get("error") {
296 error_msg = error.get("message").and_then(|m| m.as_str()).unwrap_or("Unknown error").to_string();
297 let error_status = error.get("status").and_then(|s| s.as_str()).unwrap_or("Unknown status");
298 if error_status == "INVALID_ARGUMENT" && error_msg.to_lowercase().contains("exceeds") {
299 return Err(ProviderError::ContextLengthExceeded(error_msg.to_string()));
300 }
301 }
302 }
303 tracing::debug!(
304 "{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload)
305 );
306 Err(ProviderError::RequestFailed(format!("Request failed with status: {}. Message: {}", final_status, error_msg)))
307 }
308 StatusCode::TOO_MANY_REQUESTS => {
309 let retry_delay = payload.as_ref().and_then(parse_google_retry_delay);
310 Err(ProviderError::RateLimitExceeded {
311 details: format!("{:?}", payload),
312 retry_delay,
313 })
314 }
315 _ if final_status.is_server_error() => Err(ProviderError::ServerError(
316 format_server_error_message(final_status, payload.as_ref()),
317 )),
318 _ => {
319 tracing::debug!(
320 "{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload)
321 );
322 Err(ProviderError::RequestFailed(format!("Request failed with status: {}", final_status)))
323 }
324 }
325}
326
327pub fn sanitize_function_name(name: &str) -> String {
328 static RE: OnceLock<Regex> = OnceLock::new();
329 let re = RE.get_or_init(|| Regex::new(r"[^a-zA-Z0-9_-]").unwrap());
330 re.replace_all(name, "_").to_string()
331}
332
333pub fn is_valid_function_name(name: &str) -> bool {
334 static RE: OnceLock<Regex> = OnceLock::new();
335 let re = RE.get_or_init(|| Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap());
336 re.is_match(name)
337}
338
339pub fn get_model(data: &Value) -> String {
341 if let Some(model) = data.get("model") {
342 if let Some(model_str) = model.as_str() {
343 model_str.to_string()
344 } else {
345 "Unknown".to_string()
346 }
347 } else {
348 "Unknown".to_string()
349 }
350}
351
/// Checks whether the file at `path` starts with a known image signature.
///
/// Recognized signatures are PNG (`89 50 4E 47`), JPEG (`FF D8 FF`) and GIF
/// (`47 49 46 38`). Returns `false` when the file cannot be opened or read.
fn is_image_file(path: &Path) -> bool {
    let Ok(mut file) = std::fs::File::open(path) else {
        return false;
    };

    // Zero-initialized, so a file shorter than the signature leaves padding
    // bytes that the fully specified patterns cannot match.
    let mut header = [0u8; 8];
    if file.read(&mut header).is_err() {
        return false;
    }

    matches!(
        header,
        [0x89, 0x50, 0x4E, 0x47, ..] | [0xFF, 0xD8, 0xFF, ..] | [0x47, 0x49, 0x46, 0x38, ..]
    )
}
371
372pub fn detect_image_path(text: &str) -> Option<&str> {
374 let extensions = [".png", ".jpg", ".jpeg"];
376
377 for word in text.split_whitespace() {
379 if extensions
380 .iter()
381 .any(|ext| word.to_lowercase().ends_with(ext))
382 {
383 let path = Path::new(word);
384 if path.is_absolute() && path.is_file() {
386 if is_image_file(path) {
388 return Some(word);
389 }
390 }
391 }
392 }
393 None
394}
395
396pub fn load_image_file(path: &str) -> Result<ImageContent, ProviderError> {
398 let path = Path::new(path);
399
400 if !is_image_file(path) {
402 return Err(ProviderError::RequestFailed(
403 "File is not a valid image".to_string(),
404 ));
405 }
406
407 let bytes = std::fs::read(path)
409 .map_err(|e| ProviderError::RequestFailed(format!("Failed to read image file: {}", e)))?;
410
411 let mime_type = match path.extension().and_then(|e| e.to_str()) {
413 Some(ext) => match ext.to_lowercase().as_str() {
414 "png" => "image/png",
415 "jpg" | "jpeg" => "image/jpeg",
416 _ => {
417 return Err(ProviderError::RequestFailed(
418 "Unsupported image format".to_string(),
419 ))
420 }
421 },
422 None => {
423 return Err(ProviderError::RequestFailed(
424 "Unknown image format".to_string(),
425 ))
426 }
427 };
428
429 let data = base64::prelude::BASE64_STANDARD.encode(&bytes);
431
432 Ok(RawImageContent {
433 mime_type: mime_type.to_string(),
434 data,
435 meta: None,
436 }
437 .no_annotation())
438}
439
440pub fn unescape_json_values(value: &Value) -> Value {
441 let mut cloned = value.clone();
442 unescape_json_values_in_place(&mut cloned);
443 cloned
444}
445
446fn unescape_json_values_in_place(value: &mut Value) {
447 match value {
448 Value::Object(map) => {
449 for v in map.values_mut() {
450 unescape_json_values_in_place(v);
451 }
452 }
453 Value::Array(arr) => {
454 for v in arr.iter_mut() {
455 unescape_json_values_in_place(v);
456 }
457 }
458 Value::String(s) => {
459 if s.contains('\\') {
460 *s = s
461 .replace("\\\\n", "\n")
462 .replace("\\\\t", "\t")
463 .replace("\\\\r", "\r")
464 .replace("\\\\\"", "\"")
465 .replace("\\n", "\n")
466 .replace("\\t", "\t")
467 .replace("\\r", "\r")
468 .replace("\\\"", "\"");
469 }
470 }
471 _ => {}
472 }
473}
474
/// JSONL log of a single provider request: the request itself followed by the
/// response items or errors written while it is in flight.
pub struct RequestLog {
    /// Buffered handle to the in-progress log file; `None` once the log has
    /// been finished.
    writer: Option<BufWriter<File>>,
    /// Path of the in-progress log file, renamed into the rotation when the
    /// log is finished.
    temp_path: PathBuf,
}
479
/// Number of rotated `llm_request.<n>.jsonl` files retained in the logs
/// directory (indices `0` through `LOGS_TO_KEEP - 1`).
pub const LOGS_TO_KEEP: usize = 10;
481
482impl RequestLog {
483 pub fn start<Payload>(model_config: &ModelConfig, payload: &Payload) -> Result<Self>
484 where
485 Payload: Serialize,
486 {
487 let logs_dir = Paths::in_state_dir("logs");
488
489 let request_id = Uuid::new_v4();
490 let temp_name = format!("llm_request.{request_id}.jsonl");
491 let temp_path = logs_dir.join(PathBuf::from(temp_name));
492
493 let mut writer = BufWriter::new(
494 File::options()
495 .write(true)
496 .create(true)
497 .truncate(true)
498 .open(&temp_path)?,
499 );
500
501 let data = serde_json::json!({
502 "model_config": model_config,
503 "input": payload,
504 });
505 writeln!(writer, "{}", serde_json::to_string(&data)?)?;
506
507 Ok(Self {
508 writer: Some(writer),
509 temp_path,
510 })
511 }
512
513 fn write_json(&mut self, line: &serde_json::Value) -> Result<()> {
514 let writer = self
515 .writer
516 .as_mut()
517 .ok_or_else(|| anyhow!("logger is finished"))?;
518 writeln!(writer, "{}", serde_json::to_string(line)?)?;
519 Ok(())
520 }
521
522 pub fn error<E>(&mut self, error: E) -> Result<()>
523 where
524 E: Display,
525 {
526 self.write_json(&serde_json::json!({
527 "error": format!("{}", error),
528 }))
529 }
530
531 pub fn write<Payload>(&mut self, data: &Payload, usage: Option<&Usage>) -> Result<()>
532 where
533 Payload: Serialize,
534 {
535 self.write_json(&serde_json::json!({
536 "data": data,
537 "usage": usage,
538 }))
539 }
540
541 fn finish(&mut self) -> Result<()> {
542 if let Some(mut writer) = self.writer.take() {
543 writer.flush()?;
544 let logs_dir = Paths::in_state_dir("logs");
545 let log_path = |i| logs_dir.join(format!("llm_request.{}.jsonl", i));
546
547 for i in (0..LOGS_TO_KEEP - 1).rev() {
548 let _ = std::fs::rename(log_path(i), log_path(i + 1));
549 }
550
551 std::fs::rename(&self.temp_path, log_path(0))?;
552 }
553 Ok(())
554 }
555}
556
557impl Drop for RequestLog {
558 fn drop(&mut self) {
559 if std::thread::panicking() {
560 return;
561 }
562 let _ = self.finish();
563 }
564}
565
566pub fn safely_parse_json(s: &str) -> Result<serde_json::Value, serde_json::Error> {
575 match serde_json::from_str(s) {
577 Ok(value) => Ok(value),
578 Err(_) => {
579 let escaped = json_escape_control_chars_in_string(s);
581 serde_json::from_str(&escaped)
582 }
583 }
584}
585
/// Escapes ASCII control characters (U+0000 through U+001F) in `s` using JSON
/// escape syntax.
///
/// Backspace, form feed, newline, carriage return and tab use their short
/// forms (`\b`, `\f`, `\n`, `\r`, `\t`); every other control character becomes
/// `\u00XX` with lowercase hex digits. All other characters, including quotes
/// and backslashes, are copied unchanged.
///
/// The escaping is applied to every character of `s`, whether or not it sits
/// inside a JSON string literal.
pub fn json_escape_control_chars_in_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());

    for ch in s.chars() {
        match ch {
            // The short forms must precede the catch-all control range below.
            '\u{0008}' => escaped.push_str("\\b"),
            '\u{000C}' => escaped.push_str("\\f"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\t' => escaped.push_str("\\t"),
            '\u{0000}'..='\u{001F}' => {
                escaped.push_str(&format!("\\u{:04x}", ch as u32));
            }
            _ => escaped.push(ch),
        }
    }

    escaped
}
631
// Unit tests for the helpers in this file. Tests that touch the filesystem use
// a temporary directory that is removed when the test ends.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Only absolute paths to existing files with a real image signature are
    /// detected.
    #[test]
    fn test_detect_image_path() {
        let temp_dir = tempfile::tempdir().unwrap();
        // A file that starts with the PNG signature.
        let png_path = temp_dir.path().join("test.png");
        let png_data = [
            0x89, 0x50, 0x4E, 0x47, // PNG magic number
            0x0D, 0x0A, 0x1A, 0x0A, // remainder of the PNG signature
            0x00, 0x00, 0x00, 0x0D, // filler bytes
        ];
        std::fs::write(&png_path, png_data).unwrap();
        let png_path_str = png_path.to_str().unwrap();

        // A file with a .png extension but no image signature.
        let fake_png_path = temp_dir.path().join("fake.png");
        std::fs::write(&fake_png_path, b"not a real png").unwrap();

        // A valid image path embedded in a sentence is found.
        let text = format!("Here is an image {}", png_path_str);
        assert_eq!(detect_image_path(&text), Some(png_path_str));

        // The right extension alone is not enough.
        let text = format!("Here is a fake image {}", fake_png_path.to_str().unwrap());
        assert_eq!(detect_image_path(&text), None);

        // A file that does not exist.
        let text = "Here is a fake.png that doesn't exist";
        assert_eq!(detect_image_path(text), None);

        // A non-image extension.
        let text = "Here is a file.txt";
        assert_eq!(detect_image_path(text), None);

        // Relative paths are never accepted.
        let text = "Here is a relative/path/image.png";
        assert_eq!(detect_image_path(text), None);
    }

    /// Loading succeeds for a real PNG and fails for fake, missing and
    /// unsupported files.
    #[test]
    fn test_load_image_file() {
        let temp_dir = tempfile::tempdir().unwrap();
        // A file that starts with the PNG signature.
        let png_path = temp_dir.path().join("test.png");
        let png_data = [
            0x89, 0x50, 0x4E, 0x47, // PNG magic number
            0x0D, 0x0A, 0x1A, 0x0A, // remainder of the PNG signature
            0x00, 0x00, 0x00, 0x0D, // filler bytes
        ];
        std::fs::write(&png_path, png_data).unwrap();
        let png_path_str = png_path.to_str().unwrap();

        // A file with a .png extension but no image signature.
        let fake_png_path = temp_dir.path().join("fake.png");
        std::fs::write(&fake_png_path, b"not a real png").unwrap();
        let fake_png_path_str = fake_png_path.to_str().unwrap();

        // The valid PNG loads and is tagged with the PNG MIME type.
        let result = load_image_file(png_path_str);
        assert!(result.is_ok());
        let image = result.unwrap();
        assert_eq!(image.mime_type, "image/png");

        // The fake PNG is rejected by the signature check.
        let result = load_image_file(fake_png_path_str);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("not a valid image"));

        // A missing file is an error as well.
        let result = load_image_file("nonexistent.png");
        assert!(result.is_err());

        // A GIF passes the signature check but its extension is not supported.
        let gif_path = temp_dir.path().join("test.gif");
        let gif_data = [0x47, 0x49, 0x46, 0x38, 0x39, 0x61];
        std::fs::write(&gif_path, gif_data).unwrap();
        let gif_path_str = gif_path.to_str().unwrap();

        let result = load_image_file(gif_path_str);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Unsupported image format"));
    }

    /// Allowed characters are kept; everything else becomes an underscore.
    #[test]
    fn test_sanitize_function_name() {
        assert_eq!(sanitize_function_name("hello-world"), "hello-world");
        assert_eq!(sanitize_function_name("hello world"), "hello_world");
        assert_eq!(sanitize_function_name("hello@world"), "hello_world");
    }

    /// Names are valid only when every character is allowed.
    #[test]
    fn test_is_valid_function_name() {
        assert!(is_valid_function_name("hello-world"));
        assert!(is_valid_function_name("hello_world"));
        assert!(!is_valid_function_name("hello world"));
        assert!(!is_valid_function_name("hello@world"));
    }

    /// Strings inside an object are unescaped.
    #[test]
    fn unescape_json_values_with_object() {
        let value = json!({"text": "Hello\\nWorld"});
        let unescaped_value = unescape_json_values(&value);
        assert_eq!(unescaped_value, json!({"text": "Hello\nWorld"}));
    }

    /// Strings inside an array are unescaped.
    #[test]
    fn unescape_json_values_with_array() {
        let value = json!(["Hello\\nWorld", "Goodbye\\tWorld"]);
        let unescaped_value = unescape_json_values(&value);
        assert_eq!(unescaped_value, json!(["Hello\nWorld", "Goodbye\tWorld"]));
    }

    /// A bare top-level string is unescaped.
    #[test]
    fn unescape_json_values_with_string() {
        let value = json!("Hello\\nWorld");
        let unescaped_value = unescape_json_values(&value);
        assert_eq!(unescaped_value, json!("Hello\nWorld"));
    }

    /// Single- and double-escaped sequences are handled at every nesting level.
    #[test]
    fn unescape_json_values_with_mixed_content() {
        let value = json!({
            "text": "Hello\\nWorld\\\\n!",
            "array": ["Goodbye\\tWorld", "See you\\rlater"],
            "nested": {
                "inner_text": "Inner\\\"Quote\\\""
            }
        });
        let unescaped_value = unescape_json_values(&value);
        assert_eq!(
            unescaped_value,
            json!({
                "text": "Hello\nWorld\n!",
                "array": ["Goodbye\tWorld", "See you\rlater"],
                "nested": {
                    "inner_text": "Inner\"Quote\""
                }
            })
        );
    }

    /// Strings without escape sequences are returned unchanged.
    #[test]
    fn unescape_json_values_with_no_escapes() {
        let value = json!({"text": "Hello World"});
        let unescaped_value = unescape_json_values(&value);
        assert_eq!(unescaped_value, json!({"text": "Hello World"}));
    }

    /// Detection is a case-insensitive search for "google" in the model name.
    #[test]
    fn test_is_google_model() {
        // (payload, expected result)
        let test_cases = vec![
            (json!({ "model": "google_gemini" }), true),
            (json!({ "model": "microsoft_bing" }), false),
            (json!({ "model": "" }), false),
            (json!({}), false),
            (json!({ "model": "Google_XYZ" }), true),
            (json!({ "model": "google_abc" }), true),
        ];

        for (payload, expected_result) in test_cases {
            assert_eq!(is_google_model(&payload), expected_result);
        }
    }

    /// A successful status with no error in the body stays successful.
    #[test]
    fn test_get_google_final_status_success() {
        let status = StatusCode::OK;
        let payload = json!({});
        let result = get_google_final_status(status, Some(&payload));
        assert_eq!(result, StatusCode::OK);
    }

    /// An error code in the body overrides a successful HTTP status, while a
    /// failing HTTP status is never overridden.
    #[test]
    fn test_get_google_final_status_with_error_code() {
        // (error code placed in the body, HTTP status, expected final status).
        // An HTTP status of `None` means "200 OK with an empty body".
        let test_cases = vec![
            (200, None, StatusCode::OK),
            (429, Some(StatusCode::OK), StatusCode::TOO_MANY_REQUESTS),
            (400, Some(StatusCode::OK), StatusCode::BAD_REQUEST),
            (401, Some(StatusCode::OK), StatusCode::UNAUTHORIZED),
            (403, Some(StatusCode::OK), StatusCode::FORBIDDEN),
            (404, Some(StatusCode::OK), StatusCode::NOT_FOUND),
            (500, Some(StatusCode::OK), StatusCode::INTERNAL_SERVER_ERROR),
            (503, Some(StatusCode::OK), StatusCode::SERVICE_UNAVAILABLE),
            (999, Some(StatusCode::OK), StatusCode::INTERNAL_SERVER_ERROR),
            (500, Some(StatusCode::BAD_REQUEST), StatusCode::BAD_REQUEST),
            (
                404,
                Some(StatusCode::INTERNAL_SERVER_ERROR),
                StatusCode::INTERNAL_SERVER_ERROR,
            ),
        ];

        for (error_code, status, expected_status) in test_cases {
            let payload = if let Some(_status) = status {
                json!({
                    "error": {
                        "code": error_code,
                        "message": "Error message"
                    }
                })
            } else {
                json!({})
            };

            let result = get_google_final_status(status.unwrap_or(StatusCode::OK), Some(&payload));
            assert_eq!(result, expected_status);
        }
    }

    /// Valid JSON parses directly; JSON with raw control characters inside a
    /// string is repaired; structurally broken JSON still fails.
    #[test]
    fn test_safely_parse_json() {
        // Well-formed JSON.
        let valid_json = r#"{"key1": "value1","key2": "value2"}"#;
        let result = safely_parse_json(valid_json).unwrap();
        assert_eq!(result["key1"], "value1");
        assert_eq!(result["key2"], "value2");

        // A raw newline inside a string value needs the escaping fallback.
        let invalid_json = "{\"key1\": \"value1\n\",\"key2\": \"value2\"}";
        let result = safely_parse_json(invalid_json).unwrap();
        assert_eq!(result["key1"], "value1\n");
        assert_eq!(result["key2"], "value2");

        // Already-valid input is not altered.
        let good_json = r#"{"test": "value"}"#;
        let result = safely_parse_json(good_json).unwrap();
        assert_eq!(result["test"], "value");

        // Escaping cannot rescue structurally broken JSON.
        let broken_json = r#"{"key": "unclosed_string"#;
        assert!(safely_parse_json(broken_json).is_err());

        // An empty object.
        let empty_json = "{}";
        let result = safely_parse_json(empty_json).unwrap();
        assert!(result.as_object().unwrap().is_empty());

        // A properly escaped newline is decoded by the parser itself.
        let escaped_json = r#"{"key": "value with\nnewline"}"#;
        let result = safely_parse_json(escaped_json).unwrap();
        assert_eq!(result["key"], "value with\nnewline");
    }

    /// Control characters are escaped; quotes, backslashes and ordinary text
    /// pass through untouched.
    #[test]
    fn test_json_escape_control_chars_in_string() {
        // Individual control characters.
        assert_eq!(
            json_escape_control_chars_in_string("Hello\nWorld"),
            "Hello\\nWorld"
        );
        assert_eq!(
            json_escape_control_chars_in_string("Hello\tWorld"),
            "Hello\\tWorld"
        );
        assert_eq!(
            json_escape_control_chars_in_string("Hello\rWorld"),
            "Hello\\rWorld"
        );

        // Several control characters in one string.
        assert_eq!(
            json_escape_control_chars_in_string("Hello\n\tWorld\r"),
            "Hello\\n\\tWorld\\r"
        );

        // Quotes and backslashes are not escaped.
        assert_eq!(
            json_escape_control_chars_in_string("Hello \"World\""),
            "Hello \"World\""
        );
        assert_eq!(
            json_escape_control_chars_in_string("Hello\\World"),
            "Hello\\World"
        );

        // A JSON-shaped string containing a raw newline.
        assert_eq!(
            json_escape_control_chars_in_string("{\"message\": \"Hello\nWorld\"}"),
            "{\"message\": \"Hello\\nWorld\"}"
        );

        // No control characters at all.
        assert_eq!(
            json_escape_control_chars_in_string("Hello World"),
            "Hello World"
        );

        // Control characters without a short form use the \u00XX form.
        assert_eq!(
            json_escape_control_chars_in_string("Hello\u{0001}World"),
            "Hello\\u0001World"
        );
    }

    /// The delay is read from the RetryInfo entry of `error.details`.
    #[test]
    fn test_parse_google_retry_delay() {
        let payload = json!({
            "error": {
                "details": [
                    {
                        "@type": "type.googleapis.com/google.rpc.RetryInfo",
                        "retryDelay": "42s"
                    }
                ]
            }
        });
        assert_eq!(
            parse_google_retry_delay(&payload),
            Some(Duration::from_secs(42))
        );
    }
}