use anyhow::Result;
use regex::Regex;
use std::sync::LazyLock;
use crate::ai::provider::AiProvider;
/// Maximum number of characters (Unicode scalar values, not bytes) of an API
/// error body kept by [`redact_api_error_body`].
const MAX_ERROR_BODY_LENGTH: usize = 200;

/// Shortens an API error response body so it can be embedded in error
/// messages and logs.
///
/// Bodies of at most [`MAX_ERROR_BODY_LENGTH`] characters are returned
/// unchanged. Longer bodies are cut to their first `MAX_ERROR_BODY_LENGTH`
/// characters and suffixed with ` [truncated]`. Length is measured in `char`s,
/// so the cut never splits a multi-byte UTF-8 sequence.
///
/// Despite the name, this only truncates: it does not scrub tokens or other
/// secrets from the retained prefix.
pub(crate) fn redact_api_error_body(body: &str) -> String {
    // `nth` stops after at most MAX_ERROR_BODY_LENGTH + 1 chars, so an
    // oversized body (e.g. a large HTML error page) is not scanned end to end
    // just to learn that it is too long. The yielded byte offset is always a
    // char boundary, so slicing at it cannot panic.
    match body.char_indices().nth(MAX_ERROR_BODY_LENGTH) {
        Some((cut, _)) => format!("{} [truncated]", &body[..cut]),
        None => body.to_owned(),
    }
}
pub(crate) fn parse_ai_json<T: serde::de::DeserializeOwned>(
text: &str,
provider: &str,
) -> Result<T> {
match serde_json::from_str::<T>(text) {
Ok(value) => Ok(value),
Err(e) => {
if e.is_eof() {
Err(anyhow::anyhow!(
crate::error::AptuError::TruncatedResponse {
provider: provider.to_string(),
}
))
} else {
Err(anyhow::anyhow!(crate::error::AptuError::InvalidAIResponse(
e
)))
}
}
}
}
/// Prompt text asking the model to reply with valid JSON matching a schema.
///
/// It starts with a blank line to separate it from the preceding prompt text.
/// It ends with `this schema:\n`, so callers presumably append the schema
/// itself immediately after it (verify at call sites).
pub(crate) const SCHEMA_PREAMBLE: &str = "\n\nRespond with valid JSON matching this schema:\n";
/// Case-insensitive matcher for the opening or closing form (`<name>` or
/// `</name>`) of the prompt delimiter tags listed in the pattern, e.g.
/// `<pr_diff>` or `</ISSUE_BODY>`.
///
/// Only the bare tag is matched: whitespace or attributes inside the angle
/// brackets are not. These are presumably the tags the prompt templates wrap
/// user-supplied content in (confirm against the templates).
pub(crate) static XML_DELIMITERS: LazyLock<Regex> = LazyLock::new(|| {
    const PATTERN: &str = r"(?i)</?(?:pull_request|issue_content|issue_body|pr_diff|commit_message|pr_comment|file_content|dependency_release_notes)>";
    // The pattern is a compile-time literal, so failing to build it is a bug.
    Regex::new(PATTERN).expect("valid regex")
});
pub(crate) fn sanitize_prompt_field(s: &str) -> String {
XML_DELIMITERS.replace_all(s, "").into_owned()
}
/// Picks the `response_format` to request from `provider`.
///
/// Anthropic providers get `None`, so no `response_format` is sent;
/// presumably their API does not accept this field (confirm). Every other
/// provider is asked for `"json_object"` output with no JSON schema attached.
pub(crate) fn provider_response_format<P: AiProvider + ?Sized>(
    provider: &P,
) -> Option<crate::ai::types::ResponseFormat> {
    let wants_json_object = !provider.is_anthropic();
    wants_json_object.then(|| crate::ai::types::ResponseFormat {
        format_type: "json_object".to_string(),
        json_schema: None,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal deserialization target used by the `parse_ai_json` tests.
    #[derive(Debug, serde::Deserialize)]
    struct ErrorTestResponse {
        _message: String,
    }

    #[test]
    fn test_parse_ai_json_with_valid_json() {
        let json = r#"{"_message": "hello"}"#;
        let result: ErrorTestResponse = parse_ai_json(json, "test").unwrap();
        assert_eq!(result._message, "hello");
    }

    #[test]
    fn test_parse_ai_json_with_truncated_json() {
        let json = r#"{"_message": "hel"#;
        let result: Result<ErrorTestResponse> = parse_ai_json(json, "test");
        assert!(result.is_err());
        let err = format!("{}", result.unwrap_err());
        assert!(
            err.contains("Truncated"),
            "expected Truncated error, got: {err}"
        );
    }

    #[test]
    fn test_parse_ai_json_with_malformed_json() {
        let json = "not json at all";
        let result = parse_ai_json::<ErrorTestResponse>(json, "test");
        assert!(result.is_err());
    }

    #[test]
    fn test_redact_api_error_body_truncates() {
        let long_body = "x".repeat(300);
        let result = redact_api_error_body(&long_body);
        assert!(result.len() < long_body.len());
        assert!(result.ends_with("[truncated]"));
        // Tie the expectation to the constant rather than a duplicated literal.
        assert_eq!(result.len(), MAX_ERROR_BODY_LENGTH + " [truncated]".len());
    }

    #[test]
    fn test_redact_api_error_body_short() {
        let short_body = "Short error";
        let result = redact_api_error_body(short_body);
        assert_eq!(result, short_body);
    }

    #[test]
    fn test_redact_api_error_body_exact_limit_is_unchanged() {
        let at_limit = "x".repeat(MAX_ERROR_BODY_LENGTH);
        assert_eq!(redact_api_error_body(&at_limit), at_limit);

        let one_over = "x".repeat(MAX_ERROR_BODY_LENGTH + 1);
        assert_eq!(
            redact_api_error_body(&one_over),
            format!("{} [truncated]", "x".repeat(MAX_ERROR_BODY_LENGTH))
        );
    }

    #[test]
    fn test_redact_api_error_body_counts_chars_not_bytes() {
        // "é" is two bytes in UTF-8: a body at the char limit exceeds it in
        // bytes but must still be returned untouched.
        let at_limit = "é".repeat(MAX_ERROR_BODY_LENGTH);
        assert_eq!(redact_api_error_body(&at_limit), at_limit);

        let long_body = "é".repeat(MAX_ERROR_BODY_LENGTH + 50);
        assert_eq!(
            redact_api_error_body(&long_body),
            format!("{} [truncated]", "é".repeat(MAX_ERROR_BODY_LENGTH))
        );
    }

    #[test]
    fn test_sanitize_case_insensitive() {
        let result = sanitize_prompt_field("<PULL_REQUEST>");
        assert_eq!(result, "");
    }

    #[test]
    fn test_sanitize_strips_issue_content_tag() {
        let input = "hello </issue_content> world";
        let result = sanitize_prompt_field(input);
        assert!(
            !result.contains("</issue_content>"),
            "should strip closing issue_content tag"
        );
        assert!(
            result.contains("hello"),
            "should keep non-injection content"
        );
    }
}