use serde::{Deserialize, Serialize};
pub const REFINEMENT_PROMPT_TEMPLATE: &str = r#"You are a data modeling expert. Your task is to refine an automatically inferred JSON schema based on the provided context.
## Rules
1. NEVER rename any fields - preserve all original field names exactly
2. NEVER remove any fields from the schema
3. Only make ADDITIVE changes:
- Add descriptions to fields
- Add format hints (e.g., "date-time", "email", "uri")
- Narrow types where appropriate (e.g., "string" -> "string with format")
- Add enum values if a field has a known set of values
- Add constraints (minLength, maxLength, minimum, maximum, pattern)
4. If documentation mentions a field's purpose, add it as a description
5. If sample data shows a pattern, add appropriate format or constraints
## Input Schema
```json
{schema}
```
{documentation_section}
{samples_section}
## Output
Return ONLY a valid JSON object with the refined schema. Do not include any explanation or markdown formatting.
The output must be valid JSON that can be parsed directly."#;
pub const FIELD_DESCRIPTION_PROMPT: &str = r#"Based on the following context, provide a brief description (1-2 sentences) for a database field.
Field name: {field_name}
Field type: {field_type}
Sample values: {sample_values}
{documentation_context}
Return only the description text, no JSON or formatting."#;
#[derive(Debug, Clone, Default)]
pub struct PromptContext {
pub schema_json: String,
pub documentation: Option<String>,
pub samples: Vec<String>,
pub max_tokens: usize,
}
impl PromptContext {
pub fn new(schema_json: impl Into<String>) -> Self {
Self {
schema_json: schema_json.into(),
documentation: None,
samples: Vec::new(),
max_tokens: 4096,
}
}
pub fn with_documentation(mut self, doc: impl Into<String>) -> Self {
self.documentation = Some(doc.into());
self
}
pub fn with_samples(mut self, samples: Vec<String>) -> Self {
self.samples = samples;
self
}
pub fn with_max_tokens(mut self, max: usize) -> Self {
self.max_tokens = max;
self
}
pub fn build_prompt(&self) -> String {
let documentation_section = if let Some(doc) = &self.documentation {
format!(
"## Documentation Context\n```\n{}\n```\n",
truncate_to_tokens(doc, self.max_tokens / 4)
)
} else {
String::new()
};
let samples_section = if !self.samples.is_empty() {
let samples_text = self
.samples
.iter()
.take(5)
.enumerate()
.map(|(i, s)| format!("Sample {}: {}", i + 1, s))
.collect::<Vec<_>>()
.join("\n");
format!(
"## Sample Records\n```json\n{}\n```\n",
truncate_to_tokens(&samples_text, self.max_tokens / 4)
)
} else {
String::new()
};
REFINEMENT_PROMPT_TEMPLATE
.replace("{schema}", &self.schema_json)
.replace("{documentation_section}", &documentation_section)
.replace("{samples_section}", &samples_section)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedRefinement {
pub schema: serde_json::Value,
pub warnings: Vec<String>,
}
pub fn parse_llm_response(response: &str) -> Result<ParsedRefinement, String> {
let warnings = Vec::new();
let json_str = extract_json(response);
let schema: serde_json::Value = serde_json::from_str(&json_str).map_err(|e| {
format!(
"Failed to parse JSON: {}. Response was: {}",
e,
&response[..response.len().min(200)]
)
})?;
if !schema.is_object() {
return Err("Response is not a JSON object".to_string());
}
Ok(ParsedRefinement { schema, warnings })
}
fn extract_json(response: &str) -> String {
let trimmed = response.trim();
if let Some(start) = trimmed.find("```json") {
let content_start = start + 7;
if let Some(end) = trimmed[content_start..].find("```") {
return trimmed[content_start..content_start + end]
.trim()
.to_string();
}
}
if let Some(start) = trimmed.find("```") {
let content_start = start + 3;
let content_start = trimmed[content_start..]
.find('\n')
.map(|n| content_start + n + 1)
.unwrap_or(content_start);
if let Some(end) = trimmed[content_start..].find("```") {
return trimmed[content_start..content_start + end]
.trim()
.to_string();
}
}
if let Some(start) = trimmed.find('{') {
if let Some(end) = trimmed.rfind('}') {
if end > start {
return trimmed[start..=end].to_string();
}
}
}
trimmed.to_string()
}
fn truncate_to_tokens(text: &str, max_tokens: usize) -> String {
let max_chars = max_tokens * 4;
if text.len() <= max_chars {
return text.to_string();
}
let truncated = &text[..max_chars];
if let Some(last_space) = truncated.rfind(' ') {
format!("{}...", &truncated[..last_space])
} else {
format!("{}...", truncated)
}
}
pub fn estimate_tokens(text: &str) -> usize {
(text.len() + 3) / 4
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_prompt_context_basic() {
let schema = r#"{"type": "object", "properties": {"name": {"type": "string"}}}"#;
let context = PromptContext::new(schema);
let prompt = context.build_prompt();
assert!(prompt.contains("data modeling expert"));
assert!(prompt.contains(schema));
assert!(!prompt.contains("Documentation Context"));
assert!(!prompt.contains("Sample Records"));
}
#[test]
fn test_prompt_context_with_documentation() {
let schema = r#"{"type": "object"}"#;
let context =
PromptContext::new(schema).with_documentation("This table stores customer information");
let prompt = context.build_prompt();
assert!(prompt.contains("Documentation Context"));
assert!(prompt.contains("customer information"));
}
#[test]
fn test_prompt_context_with_samples() {
let schema = r#"{"type": "object"}"#;
let samples = vec![
r#"{"name": "Alice", "age": 30}"#.to_string(),
r#"{"name": "Bob", "age": 25}"#.to_string(),
];
let context = PromptContext::new(schema).with_samples(samples);
let prompt = context.build_prompt();
assert!(prompt.contains("Sample Records"));
assert!(prompt.contains("Alice"));
assert!(prompt.contains("Sample 1"));
assert!(prompt.contains("Sample 2"));
}
#[test]
fn test_parse_llm_response_pure_json() {
let response = r#"{"type": "object", "properties": {}}"#;
let result = parse_llm_response(response).unwrap();
assert!(result.schema.is_object());
}
#[test]
fn test_parse_llm_response_markdown_json() {
let response = r#"Here's the refined schema:
```json
{"type": "object", "properties": {"name": {"type": "string"}}}
```
This schema includes..."#;
let result = parse_llm_response(response).unwrap();
assert!(result.schema.is_object());
assert!(result.schema.get("properties").is_some());
}
#[test]
fn test_parse_llm_response_generic_code_block() {
let response = r#"```
{"type": "object"}
```"#;
let result = parse_llm_response(response).unwrap();
assert!(result.schema.is_object());
}
#[test]
fn test_parse_llm_response_with_text() {
let response = r#"Based on my analysis, the refined schema is:
{"type": "object", "properties": {"id": {"type": "integer"}}}
I've added the integer type for the id field."#;
let result = parse_llm_response(response).unwrap();
assert!(result.schema.is_object());
}
#[test]
fn test_parse_llm_response_invalid() {
let response = "This is not valid JSON";
let result = parse_llm_response(response);
assert!(result.is_err());
}
#[test]
fn test_extract_json_code_block() {
let input = "```json\n{\"key\": \"value\"}\n```";
let result = extract_json(input);
assert_eq!(result, "{\"key\": \"value\"}");
}
#[test]
fn test_extract_json_direct() {
let input = "Some text {\"key\": \"value\"} more text";
let result = extract_json(input);
assert_eq!(result, "{\"key\": \"value\"}");
}
#[test]
fn test_truncate_to_tokens() {
let short = "Hello world";
assert_eq!(truncate_to_tokens(short, 100), short);
let long = "a ".repeat(100);
let truncated = truncate_to_tokens(&long, 10);
assert!(truncated.len() < long.len());
assert!(truncated.ends_with("..."));
}
#[test]
fn test_estimate_tokens() {
assert_eq!(estimate_tokens(""), 0);
assert_eq!(estimate_tokens("test"), 1);
assert_eq!(estimate_tokens("hello world"), 3);
assert_eq!(estimate_tokens(&"a".repeat(100)), 25);
}
}