use super::types::{ContentBlock, CreateMessageResponse, Usage};
use crate::error::{Result, SofosError};
use colored::Colorize;
use rand::RngExt;
use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue};
use std::future::Future;
use std::time::Duration;
/// Long HTTP client timeout of 30 minutes (e.g. passed to `build_http_client`).
pub const REQUEST_TIMEOUT: Duration = Duration::from_secs(30 * 60);
/// HTTP client timeout of 10 minutes; judging by the name it is meant for Morph
/// requests — confirm at the call sites.
pub const MORPH_REQUEST_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// Number of times `with_retries` re-runs a retryable failure. The first call is
/// not a retry, so at most `MAX_RETRIES + 1` attempts are made.
pub const MAX_RETRIES: u32 = 2;
/// Backoff before the first retry, in milliseconds; `with_retries` doubles it
/// after every retry.
pub const INITIAL_RETRY_DELAY_MS: u64 = 1_000;
/// Upper bound of the random stretch applied to each backoff (0.3 = up to +30%).
const JITTER_FACTOR: f64 = 0.3;
/// `Content-Type` added to every client unless the provider headers set one.
const DEFAULT_CONTENT_TYPE: &str = "application/json";
/// Fixed response `type` stored by `build_message_response`.
const RESPONSE_TYPE_MESSAGE: &str = "message";
/// Fixed role stored by `build_message_response`.
const ROLE_ASSISTANT: &str = "assistant";
/// Returns the provider's headers with a default `Content-Type` filled in.
///
/// A `Content-Type` already present in `provider_headers` wins; all other
/// provider headers (auth, API version, ...) pass through untouched.
fn merge_default_headers(mut provider_headers: HeaderMap) -> HeaderMap {
    if !provider_headers.contains_key(CONTENT_TYPE) {
        provider_headers.insert(
            CONTENT_TYPE,
            HeaderValue::from_static(DEFAULT_CONTENT_TYPE),
        );
    }
    provider_headers
}
/// Builds a `reqwest::Client` whose default headers are `provider_headers`
/// plus a default `Content-Type`, with an overall per-request `timeout`.
///
/// # Errors
///
/// Returns `SofosError::Config` if reqwest fails to build the client.
pub fn build_http_client(
    provider_headers: HeaderMap,
    timeout: Duration,
) -> Result<reqwest::Client> {
    let headers = merge_default_headers(provider_headers);
    let builder = reqwest::Client::builder()
        .default_headers(headers)
        .timeout(timeout);
    match builder.build() {
        Ok(client) => Ok(client),
        Err(e) => Err(SofosError::Config(format!(
            "Failed to create HTTP client: {}",
            e
        ))),
    }
}
/// Assembles a `CreateMessageResponse` from provider-agnostic pieces.
///
/// The response type is always `"message"` and the role always `"assistant"`;
/// the token counts are packed into the `usage` field.
pub fn build_message_response(
    id: String,
    model: String,
    content: Vec<ContentBlock>,
    stop_reason: Option<String>,
    input_tokens: u32,
    output_tokens: u32,
) -> CreateMessageResponse {
    let usage = Usage {
        input_tokens,
        output_tokens,
    };
    CreateMessageResponse {
        _id: id,
        _model: model,
        _role: String::from(ROLE_ASSISTANT),
        _response_type: String::from(RESPONSE_TYPE_MESSAGE),
        content,
        stop_reason,
        usage,
    }
}
/// Probes `base_url` with a `HEAD` request to check that the provider is reachable.
///
/// Any HTTP response at all, whatever its status, counts as reachable. The
/// probe is abandoned after 5 seconds.
///
/// # Errors
///
/// Returns `SofosError::NetworkError` when the request fails (the message
/// names `provider_name`, points at `status_url`, and includes the underlying
/// error) or when the 5-second limit elapses.
pub async fn check_api_connectivity(
    client: &reqwest::Client,
    base_url: &str,
    provider_name: &str,
    status_url: &str,
) -> Result<()> {
    let probe = client.head(base_url).send();
    let outcome = tokio::time::timeout(Duration::from_secs(5), probe).await;
    let Ok(send_result) = outcome else {
        return Err(SofosError::NetworkError(
            "Connection timeout. Please check your network connection.".into(),
        ));
    };
    send_result.map(|_| ()).map_err(|e| {
        SofosError::NetworkError(format!(
            "Cannot reach {} API. Please check:\n \
             1. Your internet connection\n \
             2. Firewall/proxy settings\n \
             3. API status at {}\n\
             Original error: {}",
            provider_name, status_url, e
        ))
    })
}
/// Maximum number of bytes of unparseable arguments echoed in the stderr warning.
const UNPARSEABLE_ARGS_PREVIEW_BYTES: usize = 500;

/// Parses the raw argument string of a tool call into a JSON value, repairing
/// common defects in model-generated JSON.
///
/// For the Morph edit tool (`ToolName::MorphEditFile`) parsing is strict.
/// Input that is not valid JSON becomes an empty object, and no repair is
/// attempted.
///
/// For every other tool, these candidates are tried in order and the first
/// one that parses wins:
///
/// 1. `args` unchanged (a JSON string that itself holds JSON is unwrapped one level);
/// 2. `args` with surrounding whitespace trimmed;
/// 3. commas before a closing `}` / `]` (outside strings) removed;
/// 4. raw control characters inside string literals escaped;
/// 5. if the text starts with `{`, the output of step 4 completed as follows.
///    This recovers output that was cut off mid-value.
///    - An unterminated string is closed.
///    - A dangling trailing comma is dropped.
///    - Every `{` / `[` still open is closed, innermost first.
///
/// When nothing parses, a warning is printed to stderr showing at most
/// `UNPARSEABLE_ARGS_PREVIEW_BYTES` bytes of `args`. The function then returns
/// `{"raw_arguments": args}`, so the caller still sees the original input.
pub fn parse_tool_arguments(name: &str, args: &str) -> serde_json::Value {
    if name == crate::tools::tool_name::ToolName::MorphEditFile.as_str() {
        return serde_json::from_str(args)
            .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
    }
    if let Some(v) = try_parse_json_object(args) {
        return v;
    }
    let trimmed = args.trim();
    if trimmed != args {
        if let Some(v) = try_parse_json_object(trimmed) {
            return v;
        }
    }
    let no_trailing_commas = strip_trailing_commas_outside_strings(trimmed);
    if no_trailing_commas != trimmed {
        if let Some(v) = try_parse_json_object(&no_trailing_commas) {
            return v;
        }
    }
    let escaped = escape_control_chars_in_json_strings(&no_trailing_commas);
    if escaped != no_trailing_commas {
        if let Some(v) = try_parse_json_object(&escaped) {
            return v;
        }
    }
    // Truncation recovery builds on the fully repaired text from the steps above.
    if trimmed.starts_with('{') {
        if let Some(v) = try_parse_json_object(&close_truncated_json(&escaped)) {
            return v;
        }
    }
    let preview_end = truncate_at_char_boundary(args, UNPARSEABLE_ARGS_PREVIEW_BYTES);
    eprintln!(
        " \x1b[33m⚠\x1b[0m Failed to parse tool arguments as JSON for {}: {}",
        name,
        &args[..preview_end]
    );
    serde_json::json!({"raw_arguments": args})
}

/// Best-effort completion of JSON text that ends abruptly.
///
/// Steps, in order:
/// - Close an unterminated string literal. If a lone trailing backslash would
///   escape the added quote, drop that backslash.
/// - Trim trailing commas and whitespace.
/// - Append a closer for every `{` / `[` still open outside string literals,
///   innermost first.
///
/// The result is only a candidate; the caller must still parse it.
fn close_truncated_json(text: &str) -> String {
    let mut repaired = text.to_string();
    if string_is_open(&repaired) {
        repaired.push('"');
        if string_is_open(&repaired) {
            // The quote we just added was escaped by a dangling `\`, so the
            // text now ends with the two ASCII bytes `\"`; replace them with `"`.
            repaired.truncate(repaired.len() - 2);
            repaired.push('"');
        }
    }
    let kept = repaired
        .trim_end_matches(|c: char| c == ',' || c.is_whitespace())
        .len();
    repaired.truncate(kept);

    // Stack of closers still owed, tracked outside string literals only.
    let mut closers = Vec::new();
    let mut in_string = false;
    let mut prev_backslash = false;
    for ch in repaired.chars() {
        if in_string {
            if prev_backslash {
                prev_backslash = false;
            } else if ch == '\\' {
                prev_backslash = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }
        match ch {
            '"' => in_string = true,
            '{' => closers.push('}'),
            '[' => closers.push(']'),
            '}' | ']' => {
                closers.pop();
            }
            _ => {}
        }
    }
    repaired.extend(closers.into_iter().rev());
    repaired
}
/// Parses `s` as JSON, returning `None` if it is not valid JSON.
///
/// Some models double-encode arguments: they send a JSON string whose content
/// is itself JSON. When the parsed value is a string whose trimmed content
/// starts with `{` or `[` and parses, that inner value is returned instead.
/// Any other successfully parsed value (including non-objects) is returned
/// unchanged.
fn try_parse_json_object(s: &str) -> Option<serde_json::Value> {
    let parsed = serde_json::from_str::<serde_json::Value>(s).ok()?;
    let unwrapped = match &parsed {
        serde_json::Value::String(inner) => {
            let head = inner.trim_start();
            if head.starts_with('{') || head.starts_with('[') {
                serde_json::from_str::<serde_json::Value>(inner).ok()
            } else {
                None
            }
        }
        _ => None,
    };
    Some(unwrapped.unwrap_or(parsed))
}
/// Removes every comma that precedes a closing `}` or `]`, possibly with JSON
/// whitespace (space, tab, LF, CR) in between, outside string literals.
///
/// Models often emit `{"a": 1,}`, `{"a": 1, }` or `[1, 2,\n]`; all become
/// valid after this pass. Only the comma is dropped — the whitespace after it
/// is kept — and text inside string literals (tracking `\"` escapes) is never
/// modified. A trailing comma with no closer after it is left alone.
fn strip_trailing_commas_outside_strings(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut in_string = false;
    let mut prev_backslash = false;
    let mut chars = s.chars();
    while let Some(ch) = chars.next() {
        if in_string {
            out.push(ch);
            if prev_backslash {
                prev_backslash = false;
            } else if ch == '\\' {
                prev_backslash = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }
        match ch {
            '"' => {
                in_string = true;
                out.push(ch);
            }
            // Peek (on a cheap clone, consuming nothing) past whitespace for a closer.
            ',' if matches!(
                chars
                    .clone()
                    .find(|c| !matches!(c, ' ' | '\t' | '\n' | '\r')),
                Some('}' | ']')
            ) => {}
            _ => out.push(ch),
        }
    }
    out
}
/// Escapes raw control characters (U+0000..=U+001F) that appear inside JSON
/// string literals, leaving whitespace between tokens untouched.
///
/// JSON forbids unescaped control characters in strings, yet models often emit
/// literal newlines or tabs inside long string values.
/// - LF, CR, TAB, backspace and form feed get their short escapes
///   (`\n`, `\r`, `\t`, `\b`, `\f`).
/// - Every other control character becomes `\u00XX`.
///
/// A lone backslash directly followed by a raw control character is not a
/// valid escape. It is treated as a literal backslash and emitted as `\\`
/// followed by the escaped control character, so both characters survive
/// decoding; previously the line break was turned into the letter `n`.
fn escape_control_chars_in_json_strings(s: &str) -> String {
    /// Characters JSON requires to be escaped inside strings.
    fn is_json_control(c: char) -> bool {
        c < '\u{20}'
    }
    /// Appends the JSON escape sequence for control character `c`.
    fn push_escaped_control(out: &mut String, c: char) {
        match c {
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\u{8}' => out.push_str("\\b"),
            '\u{c}' => out.push_str("\\f"),
            other => out.push_str(&format!("\\u{:04x}", other as u32)),
        }
    }

    let mut out = String::with_capacity(s.len());
    let mut in_string = false;
    let mut prev_backslash = false;
    for ch in s.chars() {
        if !in_string {
            if ch == '"' {
                in_string = true;
            }
            out.push(ch);
            continue;
        }
        if prev_backslash {
            prev_backslash = false;
            if is_json_control(ch) {
                // The backslash already in `out` becomes `\\` (a literal backslash).
                out.push('\\');
                push_escaped_control(&mut out, ch);
            } else {
                out.push(ch);
            }
            continue;
        }
        match ch {
            '\\' => {
                out.push(ch);
                prev_backslash = true;
            }
            '"' => {
                out.push(ch);
                in_string = false;
            }
            c if is_json_control(c) => push_escaped_control(&mut out, c),
            _ => out.push(ch),
        }
    }
    out
}
/// Reports whether `s` ends inside an unterminated JSON string literal.
///
/// Double quotes toggle the in-string state; inside a string, a backslash
/// escapes the following character (so `\"` does not close the string).
fn string_is_open(s: &str) -> bool {
    let mut inside = false;
    let mut escaped = false;
    for c in s.chars() {
        if !inside {
            inside = c == '"';
            continue;
        }
        if escaped {
            escaped = false;
        } else if c == '\\' {
            escaped = true;
        } else if c == '"' {
            inside = false;
        }
    }
    inside
}
/// Returns the largest byte index `<= max_bytes` that lies on a UTF-8 char
/// boundary of `s`, so that `&s[..idx]` never panics.
///
/// When `max_bytes` reaches or exceeds `s.len()`, `s.len()` is returned.
pub fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> usize {
    if s.len() <= max_bytes {
        return s.len();
    }
    // Index 0 is always a char boundary, so the search cannot come up empty.
    (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0)
}
/// Failure of a single API call, classified so the retry loop can decide
/// whether another attempt is worthwhile.
#[derive(Debug)]
pub enum ApiCallError {
    /// `request.send()` itself returned an error, so no HTTP status was classified.
    Transport(reqwest::Error),
    /// The server answered with a 5xx status; `body` holds the response text.
    ServerError {
        status: reqwest::StatusCode,
        body: String,
    },
    /// Any other non-success status; `body` holds the response text.
    ClientError {
        status: reqwest::StatusCode,
        body: String,
    },
}

impl ApiCallError {
    /// Only server (5xx) errors are worth retrying.
    fn is_retryable(&self) -> bool {
        match self {
            Self::ServerError { .. } => true,
            Self::Transport(_) | Self::ClientError { .. } => false,
        }
    }

    /// Short, body-free summary used as the `reason` field in logs.
    fn describe(&self) -> String {
        let (label, status) = match self {
            Self::Transport(e) => return format!("Request failed: {}", e),
            Self::ServerError { status, .. } => ("Server error", status),
            Self::ClientError { status, .. } => ("Client error", status),
        };
        format!("{} {}", label, status)
    }
}
/// Passes successful responses through and turns any other status into an
/// [`ApiCallError`].
///
/// 5xx statuses become `ServerError`; every other non-success status becomes
/// `ClientError`. In both cases the response body is read as text.
pub async fn classify_response(
    response: reqwest::Response,
) -> std::result::Result<reqwest::Response, ApiCallError> {
    let status = response.status();
    if !status.is_success() {
        // Best-effort: a body that cannot be read is reported as empty rather
        // than masking the status code.
        let body = response.text().await.unwrap_or_default();
        return Err(if status.is_server_error() {
            ApiCallError::ServerError { status, body }
        } else {
            ApiCallError::ClientError { status, body }
        });
    }
    Ok(response)
}
/// Sends `request` once and classifies the outcome.
///
/// A send failure becomes `ApiCallError::Transport`; a received response is
/// handed to [`classify_response`].
pub async fn send_classified(
    request: reqwest::RequestBuilder,
) -> std::result::Result<reqwest::Response, ApiCallError> {
    match request.send().await {
        Ok(response) => classify_response(response).await,
        Err(e) => Err(ApiCallError::Transport(e)),
    }
}
/// Sends `request` exactly once (no retries) and converts any failure into a
/// `SofosError` attributed to `service_name`.
pub async fn send_once(
    service_name: &str,
    request: reqwest::RequestBuilder,
) -> Result<reqwest::Response> {
    // No retry loop here, so every failure is reported as a single attempt.
    const SINGLE_ATTEMPT: u32 = 1;
    match send_classified(request).await {
        Ok(response) => Ok(response),
        Err(e) => Err(api_call_error_to_sofos(service_name, SINGLE_ATTEMPT, e)),
    }
}
/// Maps an [`ApiCallError`] to the crate error type, recording how many
/// attempts were made.
///
/// Transport failures become `SofosError::NetworkError`. Any HTTP status
/// failure becomes `SofosError::Api`, whose message includes the status and
/// the response body.
fn api_call_error_to_sofos(service_name: &str, attempts: u32, e: ApiCallError) -> SofosError {
    let (status, body) = match e {
        ApiCallError::Transport(err) => {
            return SofosError::NetworkError(format!(
                "{} request failed after {} attempt(s): {}",
                service_name, attempts, err
            ));
        }
        ApiCallError::ServerError { status, body } => (status, body),
        ApiCallError::ClientError { status, body } => (status, body),
    };
    SofosError::Api(format!(
        "{} request failed with status {} after {} attempt(s): {}",
        service_name, status, attempts, body
    ))
}
/// Runs `operation` until it succeeds, retrying only retryable failures.
///
/// `operation` is called once per attempt and must produce a fresh future (and
/// request) each time.
///
/// Retry policy:
/// - Only failures whose `is_retryable()` is true (server errors) are retried.
/// - They are retried up to `MAX_RETRIES` times, so at most `MAX_RETRIES + 1`
///   attempts are made.
/// - Before each retry the task sleeps for a backoff that starts at
///   `INITIAL_RETRY_DELAY_MS` and doubles after every retry.
/// - Each backoff is stretched by a random jitter in `[0, JITTER_FACTOR)`.
///
/// `F` only needs to be `FnMut`, so closures that update local state between
/// attempts are accepted. Every `Fn` closure still works.
///
/// # Errors
///
/// Returns the first non-retryable failure, or the last failure once retries
/// are exhausted. It is converted to a `SofosError` that records the number of
/// attempts made.
pub async fn with_retries<F, Fut, T>(service_name: &str, mut operation: F) -> Result<T>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = std::result::Result<T, ApiCallError>>,
{
    let mut retry_delay = Duration::from_millis(INITIAL_RETRY_DELAY_MS);
    // Zero-based index of the attempt currently being made.
    let mut attempt: u32 = 0;
    loop {
        match operation().await {
            Ok(result) => return Ok(result),
            // Retryable failure with budget left: fall through to the backoff
            // below (the error is dropped here, before sleeping).
            Err(e) if e.is_retryable() && attempt < MAX_RETRIES => {}
            Err(e) => {
                let retryable = e.is_retryable();
                let attempts = attempt + 1;
                tracing::error!(
                    service = service_name,
                    attempts = attempts,
                    reason = %e.describe(),
                    retryable = retryable,
                    "API request failed permanently"
                );
                return Err(api_call_error_to_sofos(service_name, attempts, e));
            }
        }

        attempt += 1;
        let jitter = rand::rng().random_range(0.0..JITTER_FACTOR);
        let jittered_delay = retry_delay.mul_f64(1.0 + jitter);
        tracing::warn!(
            service = service_name,
            attempt = attempt,
            max_retries = MAX_RETRIES,
            delay_ms = jittered_delay.as_millis() as u64,
            "Retrying API request after server error"
        );
        eprintln!(
            " {} server error, retrying in {:?}... (attempt {}/{})",
            format!("{}:", service_name).bright_yellow(),
            jittered_delay,
            attempt,
            MAX_RETRIES
        );
        tokio::time::sleep(jittered_delay).await;
        // Saturate rather than panic if MAX_RETRIES is ever raised far enough to overflow.
        retry_delay = retry_delay.saturating_mul(2);
    }
}
// Unit tests for this module: retry classification and the retry loop, header
// merging, client/response construction, UTF-8-safe truncation, and the
// tool-argument JSON repair pipeline.
#[cfg(test)]
mod tests {
use super::*;
// --- Error classification and retry loop ---
#[test]
fn api_call_error_is_retryable_only_for_server_error() {
let server = ApiCallError::ServerError {
status: reqwest::StatusCode::INTERNAL_SERVER_ERROR,
body: String::new(),
};
let client = ApiCallError::ClientError {
status: reqwest::StatusCode::BAD_REQUEST,
body: String::new(),
};
assert!(server.is_retryable());
assert!(!client.is_retryable());
}
// Two server errors followed by success: three calls in total, which uses up
// the full MAX_RETRIES budget (2) and still succeeds.
#[tokio::test]
async fn with_retries_retries_server_error_then_succeeds() {
use std::sync::atomic::{AtomicU32, Ordering};
let attempts = AtomicU32::new(0);
let result: Result<&'static str> = with_retries("Test", || {
let n = attempts.fetch_add(1, Ordering::SeqCst);
async move {
if n < 2 {
Err(ApiCallError::ServerError {
status: reqwest::StatusCode::BAD_GATEWAY,
body: "retry me".into(),
})
} else {
Ok("done")
}
}
})
.await;
assert_eq!(result.unwrap(), "done");
assert_eq!(attempts.load(Ordering::SeqCst), 3);
}
// A client error is final: the operation must run exactly once.
#[tokio::test]
async fn with_retries_does_not_retry_client_error() {
use std::sync::atomic::{AtomicU32, Ordering};
let attempts = AtomicU32::new(0);
let result: Result<&'static str> = with_retries("Test", || {
attempts.fetch_add(1, Ordering::SeqCst);
async move {
Err(ApiCallError::ClientError {
status: reqwest::StatusCode::BAD_REQUEST,
body: "nope".into(),
})
}
})
.await;
assert!(result.is_err());
assert_eq!(attempts.load(Ordering::SeqCst), 1);
}
// --- Default header merging ---
#[test]
fn merge_default_headers_adds_content_type_when_absent() {
let merged = merge_default_headers(HeaderMap::new());
assert_eq!(merged.get(CONTENT_TYPE).unwrap(), DEFAULT_CONTENT_TYPE);
}
#[test]
fn merge_default_headers_respects_caller_content_type() {
let mut headers = HeaderMap::new();
headers.insert(
CONTENT_TYPE,
HeaderValue::from_static("application/vnd.api+json"),
);
let merged = merge_default_headers(headers);
assert_eq!(
merged.get(CONTENT_TYPE).unwrap(),
"application/vnd.api+json"
);
}
#[test]
fn merge_default_headers_preserves_provider_auth_headers() {
let mut headers = HeaderMap::new();
headers.insert("x-api-key", HeaderValue::from_static("secret"));
headers.insert("anthropic-version", HeaderValue::from_static("2023-06-01"));
let merged = merge_default_headers(headers);
assert_eq!(merged.get("x-api-key").unwrap(), "secret");
assert_eq!(merged.get("anthropic-version").unwrap(), "2023-06-01");
assert_eq!(merged.get(CONTENT_TYPE).unwrap(), DEFAULT_CONTENT_TYPE);
}
// --- Client and response construction ---
#[test]
fn build_http_client_succeeds_with_empty_headers() {
assert!(build_http_client(HeaderMap::new(), REQUEST_TIMEOUT).is_ok());
}
#[test]
fn build_message_response_populates_constant_fields() {
let r = build_message_response(
"id-42".into(),
"test-model".into(),
vec![],
Some("max_tokens".into()),
100,
50,
);
assert_eq!(r._id, "id-42");
assert_eq!(r._model, "test-model");
assert_eq!(r._role, "assistant");
assert_eq!(r._response_type, "message");
assert_eq!(r.stop_reason.as_deref(), Some("max_tokens"));
assert_eq!(r.usage.input_tokens, 100);
assert_eq!(r.usage.output_tokens, 50);
assert!(r.content.is_empty());
}
// --- UTF-8-safe truncation ---
#[test]
fn test_truncate_at_char_boundary_ascii() {
assert_eq!(truncate_at_char_boundary("hello world", 5), 5);
assert_eq!(truncate_at_char_boundary("hello", 10), 5);
assert_eq!(truncate_at_char_boundary("hello", 0), 0);
assert_eq!(truncate_at_char_boundary("", 5), 0);
}
// '─' occupies bytes 2..5, so cuts at 3 and 4 fall back to 2.
#[test]
fn test_truncate_at_char_boundary_multibyte() {
let s = "ab─cd";
assert_eq!(s.len(), 7); assert_eq!(truncate_at_char_boundary(s, 3), 2);
assert_eq!(truncate_at_char_boundary(s, 4), 2);
assert_eq!(truncate_at_char_boundary(s, 5), 5);
}
// The 4-byte crab emoji occupies bytes 1..5; every cut inside it falls back to 1.
#[test]
fn test_truncate_at_char_boundary_emoji() {
let s = "a🦀b";
assert_eq!(s.len(), 6); assert_eq!(truncate_at_char_boundary(s, 1), 1);
assert_eq!(truncate_at_char_boundary(s, 2), 1);
assert_eq!(truncate_at_char_boundary(s, 3), 1);
assert_eq!(truncate_at_char_boundary(s, 4), 1);
assert_eq!(truncate_at_char_boundary(s, 5), 5);
}
// --- Tool-argument parsing and JSON repair ---
use crate::tools::tool_name::ToolName;
#[test]
fn parse_args_valid_object_round_trips() {
let v = parse_tool_arguments("read_file", r#"{"path":"src/main.rs"}"#);
assert_eq!(v["path"], "src/main.rs");
}
#[test]
fn parse_args_repairs_trailing_comma() {
let v = parse_tool_arguments("read_file", r#"{"path":"src/main.rs",}"#);
assert_eq!(v["path"], "src/main.rs");
}
#[test]
fn parse_args_repairs_missing_closing_brace() {
let v = parse_tool_arguments("read_file", r#"{"path":"src/main.rs""#);
assert_eq!(v["path"], "src/main.rs");
}
#[test]
fn parse_args_unrepairable_falls_back_to_raw_arguments() {
let v = parse_tool_arguments("read_file", "not json at all");
assert_eq!(v["raw_arguments"], "not json at all");
}
// Raw (unescaped) newlines inside a string value are escaped before parsing.
#[test]
fn parse_args_escapes_literal_newline_in_string_value() {
let raw = "{\"path\":\"foo.md\",\"content\":\"line1\nline2\nend\"}";
let v = parse_tool_arguments("write_file", raw);
assert_eq!(v["path"], "foo.md");
assert_eq!(v["content"], "line1\nline2\nend");
}
// Same repair with non-ASCII (Cyrillic) content around the raw newline.
#[test]
fn parse_args_escapes_newline_in_unicode_content() {
let raw = "{\"content\":\"# Синергията\nмежду Божия промисъл\",\"path\":\"doc.md\"}";
let v = parse_tool_arguments("write_file", raw);
assert_eq!(v["path"], "doc.md");
assert!(v["content"].as_str().unwrap().contains("Синергията"));
}
// Output cut off inside a string value: the string and object get closed.
#[test]
fn parse_args_recovers_truncated_string_mid_value() {
let raw = "{\"path\":\"foo.md\",\"content\":\"hello\nworld interrupt";
let v = parse_tool_arguments("write_file", raw);
assert_eq!(v["path"], "foo.md");
assert!(v["content"].as_str().unwrap().contains("hello"));
}
// A JSON string containing a JSON object is unwrapped one level.
#[test]
fn parse_args_unwraps_double_encoded_object() {
let raw = r#""{\"path\":\"foo.rs\"}""#;
let v = parse_tool_arguments("read_file", raw);
assert_eq!(v["path"], "foo.rs");
}
// Morph edits are parsed strictly: malformed input yields an empty object
// rather than a repaired guess.
#[test]
fn parse_args_morph_edit_strict_returns_empty_object_on_failure() {
let v = parse_tool_arguments(
ToolName::MorphEditFile.as_str(),
r#"{"target_filepath":"src/lib.rs","code_edit":"fn x() { let y = [1,2,"#,
);
assert!(v.is_object());
assert_eq!(v.as_object().unwrap().len(), 0);
}
#[test]
fn parse_args_morph_edit_valid_round_trips() {
let v = parse_tool_arguments(
ToolName::MorphEditFile.as_str(),
r#"{"target_filepath":"src/lib.rs","instructions":"add fn","code_edit":"fn x() {}"}"#,
);
assert_eq!(v["target_filepath"], "src/lib.rs");
assert_eq!(v["code_edit"], "fn x() {}");
}
// A ",}" inside a string literal must survive trailing-comma stripping.
#[test]
fn parse_args_trailing_comma_strip_respects_strings() {
let raw = r#"{"note":"list ends ,} here","path":"x.rs",}"#;
let v = parse_tool_arguments("write_file", raw);
assert_eq!(v["note"], "list ends ,} here");
assert_eq!(v["path"], "x.rs");
}
// An escaped backslash (`\\`) followed by a raw LF: the LF is escaped and the
// backslash pair is left intact.
#[test]
fn parse_args_escapes_raw_lf_after_escaped_backslash() {
let raw = "{\"path\":\"a.md\",\"content\":\"pre\\\\\npost\"}";
let v = parse_tool_arguments("write_file", raw);
assert_eq!(v["path"], "a.md");
let content = v["content"].as_str().unwrap();
assert_eq!(
content, "pre\\\npost",
"decoded value should be `pre` + backslash + LF + `post`"
);
}
#[test]
fn parse_args_empty_string_falls_back_to_raw_arguments() {
let v = parse_tool_arguments("read_file", "");
assert_eq!(v["raw_arguments"], "");
}
#[test]
fn parse_args_whitespace_only_falls_back_to_raw_arguments() {
let v = parse_tool_arguments("read_file", " \n\t");
assert_eq!(v["raw_arguments"], " \n\t");
}
// Non-object JSON that parses directly is returned unchanged.
#[test]
fn parse_args_array_root_returned_as_is() {
let v = parse_tool_arguments("read_file", "[1,2,3]");
assert!(v.is_array());
}
// --- Low-level repair helpers ---
#[test]
fn string_is_open_detects_unterminated_literal() {
assert!(string_is_open(r#"{"a":"b"#));
assert!(!string_is_open(r#"{"a":"b"}"#));
assert!(!string_is_open(r#"{"a":"b\""}"#));
assert!(string_is_open(r#"{"a":"b\""#)); }
// Newlines and spaces between tokens are not inside a string and stay raw.
#[test]
fn escape_control_chars_leaves_structural_whitespace_alone() {
let src = "{\n \"a\": \"b\"\n}";
assert_eq!(
escape_control_chars_in_json_strings(src),
"{\n \"a\": \"b\"\n}"
);
}
}