use std::sync::{Arc, LazyLock};
use dashmap::DashMap;
use serde_json::Value;
use crate::ast::output::SchemaRef;
use crate::ast::OutputFormat;
use crate::error::NikaError;
use crate::store::TaskResult;
/// Process-wide cache of parsed schema documents, keyed by the schema path
/// string handed to `validate_schema`.
///
/// Only the parsed JSON (`Value`) is stored; `validate_schema` still builds a
/// validator from the cached document on every call.
static SCHEMA_CACHE: LazyLock<DashMap<Arc<str>, Arc<Value>>> = LazyLock::new(DashMap::new);
/// Extracts a JSON value from raw task output, tolerating the wrapping that
/// text generators tend to add around it.
///
/// Strategies are tried in order and the first one that parses wins:
/// 1. the whole trimmed output;
/// 2. the body of the first Markdown code fence tagged `json`;
/// 3. the body of the first untagged Markdown code fence;
/// 4. the first balanced `{...}` or `[...]` span, starting at whichever opener
///    appears first in the text.
///
/// # Errors
/// Returns a human-readable message when the output contains no `{` or `[`,
/// or when none of the strategies yields valid JSON. The second message
/// includes a preview of at most 200 characters of the trimmed output.
fn extract_json_from_output(output: &str) -> Result<Value, String> {
    let trimmed = output.trim();

    if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
        return Ok(value);
    }

    // Tagged fence first, then a bare fence followed by a newline.
    for opening_fence in ["```json", "```\n"] {
        if let Some(value) = parse_fenced_json(trimmed, opening_fence) {
            return Ok(value);
        }
    }

    let first_brace = trimmed.find('{');
    let first_bracket = trimmed.find('[');
    let (open, close, start) = match (first_brace, first_bracket) {
        (Some(brace), Some(bracket)) if brace < bracket => ('{', '}', brace),
        (Some(_), Some(bracket)) => ('[', ']', bracket),
        (Some(brace), None) => ('{', '}', brace),
        (None, Some(bracket)) => ('[', ']', bracket),
        (None, None) => return Err("No JSON object or array found in output".to_string()),
    };

    let candidate = &trimmed[start..];
    if let Some(end) = find_balanced_end(candidate, open, close) {
        if let Ok(value) = serde_json::from_str::<Value>(&candidate[..end]) {
            return Ok(value);
        }
    }

    // Truncate by characters, not bytes: slicing at a fixed byte offset panics
    // when that offset falls inside a multi-byte UTF-8 sequence.
    let preview: String = trimmed.chars().take(200).collect();
    Err(format!(
        "Failed to extract JSON from output. First 200 chars: {}",
        preview
    ))
}

/// Parses the text between the first `opening_fence` and the next closing
/// triple-backtick fence as JSON; `None` if either fence is missing or the
/// body is not valid JSON.
fn parse_fenced_json(text: &str, opening_fence: &str) -> Option<Value> {
    let body_start = text.find(opening_fence)? + opening_fence.len();
    let body = &text[body_start..];
    let body_end = body.find("```")?;
    serde_json::from_str::<Value>(body[..body_end].trim()).ok()
}

/// Returns the byte offset just past the `close` that balances the leading
/// `open` of `candidate`, or `None` if the delimiters never balance.
///
/// Delimiters inside double-quoted JSON strings are ignored (honouring
/// backslash escapes), so a value such as `{"a": "}"}` is not cut short.
fn find_balanced_end(candidate: &str, open: char, close: char) -> Option<usize> {
    let mut depth: usize = 0;
    let mut in_string = false;
    let mut escaped = false;

    for (idx, ch) in candidate.char_indices() {
        if in_string {
            if escaped {
                // The character after a backslash never terminates the string.
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == '"' {
                in_string = false;
            }
            continue;
        }

        if ch == '"' {
            in_string = true;
        } else if ch == open {
            depth += 1;
        } else if ch == close {
            // A closer with nothing open means the span is not balanced.
            depth = depth.checked_sub(1)?;
            if depth == 0 {
                return Some(idx + ch.len_utf8());
            }
        }
    }
    None
}
/// Turns raw task output into a [`TaskResult`], applying the output policy.
///
/// * No policy, or a policy whose format is not JSON: the output is returned
///   unchanged via `TaskResult::success_str`.
/// * JSON format with empty or whitespace-only output: succeeds with
///   `Value::Null`.
/// * JSON format otherwise: JSON is extracted from the output (failure yields
///   a `NIKA-060` failed result), then checked against the policy's schema if
///   one is set (failure yields a failed result carrying the validation
///   error's text).
///
/// `duration` is forwarded to whichever result is produced.
pub async fn make_task_result(
    output: String,
    policy: Option<&crate::ast::OutputPolicy>,
    duration: std::time::Duration,
) -> TaskResult {
    // Everything that is not explicitly JSON-formatted passes through as text.
    let json_policy = match policy {
        Some(candidate) if candidate.format == OutputFormat::Json => candidate,
        _ => return TaskResult::success_str(output, duration),
    };

    if output.trim().is_empty() {
        tracing::debug!(
            target: "nika::output",
            "Empty output with JSON format, returning null"
        );
        return TaskResult::success(Value::Null, duration);
    }

    let parsed = match extract_json_from_output(&output) {
        Ok(value) => value,
        Err(reason) => {
            let message = format!("NIKA-060: Invalid JSON output: {}", reason);
            return TaskResult::failed(message, duration);
        }
    };

    if let Some(schema_ref) = json_policy.schema.as_ref() {
        if let Err(violation) = validate_schema_ref(&parsed, schema_ref).await {
            return TaskResult::failed(violation.to_string(), duration);
        }
    }

    TaskResult::success(parsed, duration)
}
/// Validates `value` against the JSON Schema stored in the file at
/// `schema_path`.
///
/// The parsed schema document is looked up in `SCHEMA_CACHE` first; on a miss
/// the file is read, parsed, and inserted into the cache under `schema_path`.
/// A validator is then built from the document and run over `value`.
///
/// # Errors
/// Returns `NikaError::SchemaFailed` when the file cannot be read, its
/// contents are not valid JSON, the document is not a valid schema, or
/// `value` violates the schema (all violation messages joined with `"; "`).
pub async fn validate_schema(value: &Value, schema_path: &str) -> Result<(), NikaError> {
    // Clone the Arc out of the map entry right away so the entry handle does
    // not outlive this statement.
    let cached = SCHEMA_CACHE
        .get(schema_path)
        .map(|entry| Arc::clone(entry.value()));

    let schema = match cached {
        Some(hit) => hit,
        None => {
            let raw = tokio::fs::read_to_string(schema_path)
                .await
                .map_err(|e| NikaError::SchemaFailed {
                    details: format!("Failed to read schema '{}': {}", schema_path, e),
                })?;
            let parsed: Value =
                serde_json::from_str(&raw).map_err(|e| NikaError::SchemaFailed {
                    details: format!("Invalid JSON in schema '{}': {}", schema_path, e),
                })?;
            let shared = Arc::new(parsed);
            SCHEMA_CACHE.insert(Arc::from(schema_path), Arc::clone(&shared));
            shared
        }
    };

    let validator = jsonschema::validator_for(&schema).map_err(|e| NikaError::SchemaFailed {
        details: format!("Invalid schema '{}': {}", schema_path, e),
    })?;

    let messages: Vec<String> = validator
        .iter_errors(value)
        .map(|violation| violation.to_string())
        .collect();
    if messages.is_empty() {
        return Ok(());
    }
    Err(NikaError::SchemaFailed {
        details: messages.join("; "),
    })
}
/// Validates `value` against whichever kind of schema `schema_ref` holds:
/// an inline schema document is checked directly, a file reference is
/// resolved and checked through `validate_schema`.
///
/// # Errors
/// Propagates the `NikaError` returned by the selected validator.
pub async fn validate_schema_ref(value: &Value, schema_ref: &SchemaRef) -> Result<(), NikaError> {
    match schema_ref {
        SchemaRef::Inline(inline_schema) => validate_inline_schema(value, inline_schema),
        SchemaRef::File(schema_path) => validate_schema(value, schema_path).await,
    }
}
/// Validates `value` against an in-memory JSON Schema document.
///
/// # Errors
/// Returns `NikaError::SchemaFailed` when `schema` itself is not a valid
/// schema, or when `value` violates it. In the latter case the details list
/// one `- <instance path>: <message>` line per violation.
pub fn validate_inline_schema(value: &Value, schema: &Value) -> Result<(), NikaError> {
    let validator = jsonschema::validator_for(schema).map_err(|e| NikaError::SchemaFailed {
        details: format!("Invalid inline schema: {e}"),
    })?;

    let bullet_lines: Vec<String> = validator
        .iter_errors(value)
        .map(|violation| format!("- {}: {}", violation.instance_path, violation))
        .collect();

    if bullet_lines.is_empty() {
        return Ok(());
    }
    Err(NikaError::SchemaFailed {
        details: format!("Output validation failed:\n{}", bullet_lines.join("\n")),
    })
}
/// Renders the schema violations of `value` as a newline-separated report.
///
/// Each line has the form `- Path '<instance path>': <message> (got: <json>)`,
/// where `<json>` is the offending instance serialized as JSON (empty if it
/// cannot be serialized). Returns `"No validation errors"` when `value`
/// conforms, and an `Invalid schema: ...` message when `schema` cannot be
/// compiled into a validator.
pub fn format_validation_errors(value: &Value, schema: &Value) -> String {
    let validator = match jsonschema::validator_for(schema) {
        Ok(validator) => validator,
        Err(e) => return format!("Invalid schema: {e}"),
    };

    let report_lines: Vec<String> = validator
        .iter_errors(value)
        .map(|violation| {
            let offending = serde_json::to_string(&*violation.instance).unwrap_or_default();
            format!(
                "- Path '{}': {} (got: {})",
                violation.instance_path, violation, offending
            )
        })
        .collect();

    if report_lines.is_empty() {
        "No validation errors".to_string()
    } else {
        report_lines.join("\n")
    }
}
/// Public entry point for JSON extraction.
///
/// Forwards `output` unchanged to `extract_json_from_output` and returns its
/// result as-is: the parsed value on success, or a descriptive message on
/// failure.
pub fn extract_json(output: &str) -> Result<Value, String> {
extract_json_from_output(output)
}
// Unit tests for JSON extraction, schema validation (file-based and inline),
// and `make_task_result`. File-based schema tests write their schema to a
// temporary file and pass its path to the code under test.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use std::time::Duration;
use tempfile::NamedTempFile;
// Validating twice against the same schema file succeeds both times and
// leaves an entry for that path in SCHEMA_CACHE.
#[tokio::test]
async fn schema_cache_works() {
let mut schema_file = NamedTempFile::new().unwrap();
writeln!(
schema_file,
r#"{{"type": "object", "properties": {{"name": {{"type": "string"}}}}}}"#
)
.unwrap();
let schema_path = schema_file.path().to_str().unwrap();
let value = serde_json::json!({"name": "test"});
assert!(validate_schema(&value, schema_path).await.is_ok());
assert!(validate_schema(&value, schema_path).await.is_ok());
assert!(SCHEMA_CACHE.contains_key(schema_path));
}
// A value missing the required "age" property is rejected; one that has it
// is accepted.
#[tokio::test]
async fn schema_validation_rejects_invalid() {
let mut schema_file = NamedTempFile::new().unwrap();
writeln!(schema_file, r#"{{"type": "object", "properties": {{"age": {{"type": "number"}}}}, "required": ["age"]}}"#).unwrap();
let schema_path = schema_file.path().to_str().unwrap();
let value = serde_json::json!({"name": "test"});
assert!(validate_schema(&value, schema_path).await.is_err());
let value = serde_json::json!({"age": 25});
assert!(validate_schema(&value, schema_path).await.is_ok());
}
// JSON format with a file-based schema: a JSON object succeeds, plain text
// that contains no JSON fails.
#[tokio::test]
async fn make_task_result_validates_json_file_schema() {
use crate::ast::OutputPolicy;
let mut schema_file = NamedTempFile::new().unwrap();
writeln!(schema_file, r#"{{"type": "object"}}"#).unwrap();
let schema_path = schema_file.path().to_string_lossy().to_string();
let policy = OutputPolicy {
format: OutputFormat::Json,
schema: Some(SchemaRef::File(schema_path)),
max_retries: None,
};
let result = make_task_result(
r#"{"key": "value"}"#.to_string(),
Some(&policy),
Duration::from_millis(100),
)
.await;
assert!(result.is_success());
let result = make_task_result(
"not json".to_string(),
Some(&policy),
Duration::from_millis(100),
)
.await;
assert!(!result.is_success());
}
// JSON format with an inline schema: output with the required "name"
// succeeds, output without it fails.
#[tokio::test]
async fn make_task_result_validates_json_inline_schema() {
use crate::ast::OutputPolicy;
let inline_schema = serde_json::json!({
"type": "object",
"properties": {
"name": { "type": "string" }
},
"required": ["name"]
});
let policy = OutputPolicy {
format: OutputFormat::Json,
schema: Some(SchemaRef::Inline(inline_schema)),
max_retries: None,
};
let result = make_task_result(
r#"{"name": "test"}"#.to_string(),
Some(&policy),
Duration::from_millis(100),
)
.await;
assert!(result.is_success());
let result = make_task_result(
r#"{"other": "value"}"#.to_string(),
Some(&policy),
Duration::from_millis(100),
)
.await;
assert!(!result.is_success());
}
// Without a policy the output comes back as a JSON string value.
#[tokio::test]
async fn make_task_result_no_policy_returns_text() {
let result = make_task_result(
"plain text output".to_string(),
None,
Duration::from_millis(50),
)
.await;
assert!(result.is_success());
assert_eq!(
result.output.as_ref(),
&serde_json::Value::String("plain text output".to_string())
);
}
// JSON format without a schema still parses the output into structured JSON.
#[tokio::test]
async fn make_task_result_json_no_schema_parses_json() {
use crate::ast::OutputPolicy;
let policy = OutputPolicy {
format: OutputFormat::Json,
schema: None, max_retries: None,
};
let result = make_task_result(
r#"{"key": "value", "nested": {"a": 1}}"#.to_string(),
Some(&policy),
Duration::from_millis(50),
)
.await;
assert!(result.is_success());
assert!(result.output.is_object());
assert_eq!(result.output["key"], "value");
assert_eq!(result.output["nested"]["a"], 1);
}
// Unparseable output under JSON format fails with an error mentioning the
// NIKA-060 code.
#[tokio::test]
async fn make_task_result_invalid_json_returns_error_with_code() {
use crate::ast::OutputPolicy;
let policy = OutputPolicy {
format: OutputFormat::Json,
schema: None,
max_retries: None,
};
let result = make_task_result(
"{ invalid json".to_string(),
Some(&policy),
Duration::from_millis(50),
)
.await;
assert!(!result.is_success());
let error_msg = result.error().expect("Should have error");
assert!(
error_msg.contains("NIKA-060"),
"Error should contain NIKA-060 code: {}",
error_msg
);
}
// Text format leaves JSON-looking output untouched as a string.
#[tokio::test]
async fn make_task_result_text_format_returns_raw_string() {
use crate::ast::OutputPolicy;
let policy = OutputPolicy {
format: OutputFormat::Text,
schema: None,
max_retries: None,
};
let result = make_task_result(
r#"{"key": "value"}"#.to_string(),
Some(&policy),
Duration::from_millis(50),
)
.await;
assert!(result.is_success());
assert!(result.output.is_string());
assert_eq!(
result.output.as_ref(),
&serde_json::Value::String(r#"{"key": "value"}"#.to_string())
);
}
// A schema path that does not exist produces a "Failed to read schema" error.
#[tokio::test]
async fn validate_schema_file_not_found_returns_error() {
let value = serde_json::json!({"name": "test"});
let result = validate_schema(&value, "/nonexistent/path/schema.json").await;
assert!(result.is_err());
let err = result.unwrap_err();
let err_string = err.to_string();
assert!(
err_string.contains("Failed to read schema"),
"Error should mention file read failure: {}",
err_string
);
}
// A schema file whose contents are not JSON produces an
// "Invalid JSON in schema" error.
#[tokio::test]
async fn validate_schema_invalid_json_in_schema_file() {
let mut schema_file = NamedTempFile::new().unwrap();
writeln!(schema_file, "{{ not valid json").unwrap();
let schema_path = schema_file.path().to_str().unwrap();
let value = serde_json::json!({"name": "test"});
let result = validate_schema(&value, schema_path).await;
assert!(result.is_err());
let err = result.unwrap_err();
let err_string = err.to_string();
assert!(
err_string.contains("Invalid JSON in schema"),
"Error should mention invalid JSON: {}",
err_string
);
}
// A schema file that is valid JSON but not a valid schema ("type" is a
// number) produces an "Invalid schema" error.
#[tokio::test]
async fn validate_schema_invalid_schema_structure() {
let mut schema_file = NamedTempFile::new().unwrap();
writeln!(schema_file, r#"{{"type": 123}}"#).unwrap();
let schema_path = schema_file.path().to_str().unwrap();
let value = serde_json::json!({"name": "test"});
let result = validate_schema(&value, schema_path).await;
assert!(result.is_err());
let err = result.unwrap_err();
let err_string = err.to_string();
assert!(
err_string.contains("Invalid schema"),
"Error should mention invalid schema: {}",
err_string
);
}
// An empty object checked against a schema with two required properties
// fails, and the error text mentions the violation.
#[tokio::test]
async fn validate_schema_multiple_validation_errors() {
let mut schema_file = NamedTempFile::new().unwrap();
writeln!(
schema_file,
r#"{{
"type": "object",
"properties": {{
"name": {{"type": "string"}},
"age": {{"type": "number"}}
}},
"required": ["name", "age"]
}}"#
)
.unwrap();
let schema_path = schema_file.path().to_str().unwrap();
let value = serde_json::json!({});
let result = validate_schema(&value, schema_path).await;
assert!(result.is_err());
let err = result.unwrap_err();
let err_string = err.to_string();
assert!(
err_string.contains("name") || err_string.contains("required"),
"Error should mention validation issues: {}",
err_string
);
}
// A 10000-element JSON array is parsed in full.
#[tokio::test]
async fn make_task_result_large_json_output() {
use crate::ast::OutputPolicy;
let policy = OutputPolicy {
format: OutputFormat::Json,
schema: None,
max_retries: None,
};
let large_array: Vec<i32> = (0..10000).collect();
let json_str = serde_json::to_string(&large_array).unwrap();
let result = make_task_result(json_str, Some(&policy), Duration::from_millis(100)).await;
assert!(result.is_success());
assert!(result.output.is_array());
assert_eq!(result.output.as_array().unwrap().len(), 10000);
}
// Non-ASCII string values survive parsing unchanged.
#[tokio::test]
async fn make_task_result_unicode_content() {
use crate::ast::OutputPolicy;
let policy = OutputPolicy {
format: OutputFormat::Json,
schema: None,
max_retries: None,
};
let json_str = r#"{"greeting": "你好世界", "emoji": "🚀✨", "japanese": "こんにちは"}"#;
let result = make_task_result(
json_str.to_string(),
Some(&policy),
Duration::from_millis(50),
)
.await;
assert!(result.is_success());
assert_eq!(result.output["greeting"], "你好世界");
assert_eq!(result.output["emoji"], "🚀✨");
assert_eq!(result.output["japanese"], "こんにちは");
}
// Ten spawned tasks validating against the same schema path all succeed.
#[tokio::test]
async fn schema_cache_concurrent_access() {
let mut schema_file = NamedTempFile::new().unwrap();
writeln!(schema_file, r#"{{"type": "object"}}"#).unwrap();
let schema_path = schema_file.path().to_str().unwrap().to_string();
let handles: Vec<_> = (0..10)
.map(|i| {
let path = schema_path.clone();
tokio::spawn(async move {
let value = serde_json::json!({"id": i});
validate_schema(&value, &path).await
})
})
.collect();
for handle in handles {
let result = handle.await.unwrap();
assert!(result.is_ok());
}
}
// The duration passed in is the duration stored on the result.
#[tokio::test]
async fn make_task_result_preserves_duration() {
let duration = Duration::from_secs(5);
let result = make_task_result("output".to_string(), None, duration).await;
assert_eq!(result.duration, duration);
}
// A top-level JSON array (mixed element types) is accepted.
#[tokio::test]
async fn make_task_result_json_array() {
use crate::ast::OutputPolicy;
let policy = OutputPolicy {
format: OutputFormat::Json,
schema: None,
max_retries: None,
};
let result = make_task_result(
r#"[1, 2, 3, "four"]"#.to_string(),
Some(&policy),
Duration::from_millis(50),
)
.await;
assert!(result.is_success());
assert!(result.output.is_array());
let arr = result.output.as_array().unwrap();
assert_eq!(arr.len(), 4);
assert_eq!(arr[3], "four");
}
// Output that is already pure JSON parses directly.
#[test]
fn extract_json_direct_parse() {
let input = r#"{"key": "value"}"#;
let result = extract_json_from_output(input).unwrap();
assert_eq!(result["key"], "value");
}
// Leading and trailing whitespace around the JSON is ignored.
#[test]
fn extract_json_with_whitespace() {
let input = r#"
{"key": "value"}
"#;
let result = extract_json_from_output(input).unwrap();
assert_eq!(result["key"], "value");
}
// JSON inside a code fence tagged "json", surrounded by prose, is extracted.
#[test]
fn extract_json_from_markdown_json_block() {
let input = r#"Here's the JSON:
```json
{"name": "Thibaut", "score": 42}
```
Hope this helps!"#;
let result = extract_json_from_output(input).unwrap();
assert_eq!(result["name"], "Thibaut");
assert_eq!(result["score"], 42);
}
// JSON inside an untagged code fence is extracted.
#[test]
fn extract_json_from_markdown_plain_block() {
let input = r#"The result:
```
{"items": [1, 2, 3]}
```
"#;
let result = extract_json_from_output(input).unwrap();
assert!(result["items"].is_array());
}
// A JSON object embedded in the middle of prose is extracted.
#[test]
fn extract_json_from_prose_with_braces() {
let input = r#"I'll generate the fortune for you:
The cosmic reading reveals: {"sign": "scorpio", "lucky_number": 7, "message": "Great things await"}
This is based on ancient wisdom."#;
let result = extract_json_from_output(input).unwrap();
assert_eq!(result["sign"], "scorpio");
assert_eq!(result["lucky_number"], 7);
}
// A fenced block may hold a top-level array.
#[test]
fn extract_json_array_from_markdown() {
let input = r#"```json
[{"id": 1}, {"id": 2}, {"id": 3}]
```"#;
let result = extract_json_from_output(input).unwrap();
assert!(result.is_array());
assert_eq!(result.as_array().unwrap().len(), 3);
}
// Nested objects after a prose prefix are extracted whole.
#[test]
fn extract_json_nested_objects() {
let input = r#"Result: {"outer": {"inner": {"deep": "value"}}}"#;
let result = extract_json_from_output(input).unwrap();
assert_eq!(result["outer"]["inner"]["deep"], "value");
}
// Braces inside a string value do not disturb parsing of pure-JSON input.
#[test]
fn extract_json_with_escaped_braces_in_strings() {
let input = r#"{"template": "Use {{variable}} syntax", "count": 1}"#;
let result = extract_json_from_output(input).unwrap();
assert_eq!(result["template"], "Use {{variable}} syntax");
}
// Text with no '{' or '[' yields the "No JSON object or array found" error.
#[test]
fn extract_json_no_json_found() {
let input = "This is just plain text without any JSON.";
let result = extract_json_from_output(input);
assert!(result.is_err());
assert!(result
.unwrap_err()
.contains("No JSON object or array found"));
}
// End to end: multi-line JSON in a fenced block with prose around it is
// parsed by make_task_result.
#[tokio::test]
async fn make_task_result_handles_markdown_wrapped_json() {
use crate::ast::OutputPolicy;
let policy = OutputPolicy {
format: OutputFormat::Json,
schema: None,
max_retries: None,
};
let llm_output = r#"Here's your fortune:
```json
{
"sign": "scorpio",
"lucky_number": 7,
"message": "The stars align in your favor"
}
```
Enjoy your reading!"#;
let result = make_task_result(
llm_output.to_string(),
Some(&policy),
Duration::from_millis(100),
)
.await;
assert!(result.is_success(), "Should parse JSON from markdown block");
assert_eq!(result.output["sign"], "scorpio");
assert_eq!(result.output["lucky_number"], 7);
}
// Empty output under JSON format succeeds with a null value.
#[tokio::test]
async fn make_task_result_empty_output_returns_null() {
use crate::ast::OutputPolicy;
let policy = OutputPolicy {
format: OutputFormat::Json,
schema: None,
max_retries: None,
};
let empty_output = "".to_string();
let result = make_task_result(empty_output, Some(&policy), std::time::Duration::ZERO).await;
assert!(result.is_success(), "Empty output should succeed with null");
assert!(result.output.is_null(), "Empty output should return null");
}
// Whitespace-only output is treated the same as empty output.
#[tokio::test]
async fn make_task_result_whitespace_output_returns_null() {
use crate::ast::OutputPolicy;
let policy = OutputPolicy {
format: OutputFormat::Json,
schema: None,
max_retries: None,
};
let whitespace_output = " \n\t ".to_string();
let result =
make_task_result(whitespace_output, Some(&policy), std::time::Duration::ZERO).await;
assert!(
result.is_success(),
"Whitespace-only output should succeed with null"
);
assert!(
result.output.is_null(),
"Whitespace-only output should return null"
);
}
// An inline schema reference accepts a conforming value.
#[tokio::test]
async fn validate_schema_ref_inline_success() {
let schema = serde_json::json!({
"type": "object",
"properties": {
"name": { "type": "string" }
},
"required": ["name"]
});
let value = serde_json::json!({"name": "test"});
let result = validate_schema_ref(&value, &SchemaRef::Inline(schema)).await;
assert!(result.is_ok());
}
// An inline schema reference rejects a value missing the required property,
// and the error text mentions it.
#[tokio::test]
async fn validate_schema_ref_inline_failure() {
let schema = serde_json::json!({
"type": "object",
"properties": {
"name": { "type": "string" }
},
"required": ["name"]
});
let value = serde_json::json!({"other": "field"});
let result = validate_schema_ref(&value, &SchemaRef::Inline(schema)).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("required") || err.contains("name"),
"Error should mention missing required field: {}",
err
);
}
// The formatted report includes the offending value (-5 violates minimum 0).
#[test]
fn format_validation_errors_shows_details() {
let schema = serde_json::json!({
"type": "object",
"properties": {
"age": { "type": "integer", "minimum": 0 }
},
"required": ["age"]
});
let value = serde_json::json!({"age": -5});
let errors = format_validation_errors(&value, &schema);
assert!(errors.contains("-5"), "Should show the invalid value");
}
}