use crate::error::{Error, Result};
use crate::glm46::{ChatMessage, ChatRequest, GLM46Client, ResponseFormat};
use crate::verification::types::*;
use std::sync::Arc;
use tokio::sync::RwLock;
pub use proof_parser::ProofFormat;
/// Generates formal proofs for recorded algorithm executions by prompting GLM-4.6.
pub struct ProofGenerator {
    /// Shared chat client used for every generation request.
    client: Arc<GLM46Client>,
}

impl ProofGenerator {
    /// Creates a generator that takes ownership of `client` (stored behind an [`Arc`]).
    pub fn new(client: GLM46Client) -> Self {
        Self {
            client: Arc::new(client),
        }
    }

    /// Asks the model for a proof of the execution described by `trace` and parses the reply.
    ///
    /// # Errors
    ///
    /// Propagates errors from the chat-completion request and from parsing the response.
    pub async fn generate_from_execution(&self, trace: &AlgorithmTrace) -> Result<Proof> {
        let prompt = self.build_proof_generation_prompt(trace);
        let response = self
            .client
            .chat_completion(ChatRequest {
                messages: vec![
                    ChatMessage::system(GENERATION_SYSTEM_PROMPT),
                    ChatMessage::user(&prompt),
                ],
                response_format: Some(ResponseFormat::Structured),
                temperature: 0.2,
                max_tokens: 4000,
                ..Default::default()
            })
            .await?;
        self.parse_generated_proof(&response.content, trace)
    }

    /// Builds the user prompt describing the algorithm, its input, steps and output.
    fn build_proof_generation_prompt(&self, trace: &AlgorithmTrace) -> String {
        format!(
            "Generate a formal mathematical proof for the following algorithm execution:\n\n\
             Algorithm: {}\n\
             Input: {}\n\
             Execution trace:\n{}\n\n\
             Output: {}\n\n\
             Requirements:\n\
             1. Use rigorous mathematical reasoning\n\
             2. Include LaTeX formatting for all mathematical expressions\n\
             3. Justify each step with appropriate theorems or axioms\n\
             4. Structure the proof clearly with step-by-step reasoning\n\
             5. Apply relevant theorems from the domain\n\
             6. Output in structured JSON format with statements, justifications, and theorems",
            trace.algorithm_name,
            trace.input,
            trace.execution_steps.join("\n"),
            trace.output
        )
    }

    /// Parses the model reply: first as a JSON `Proof`, otherwise with the line-based
    /// text parser.
    fn parse_generated_proof(&self, content: &str, trace: &AlgorithmTrace) -> Result<Proof> {
        if let Ok(proof) = serde_json::from_str::<Proof>(content) {
            return Ok(proof);
        }
        self.parse_proof_from_latex(content, trace)
    }

    /// Heuristically converts free-form (LaTeX/markdown) proof text into a `Proof`.
    ///
    /// Each non-empty, non-heading (`#`) line outside display math becomes one step.
    /// LaTeX from display blocks (`$$ … $$`, `\[ … \]`) and from inline `$…$` spans is
    /// attached to the next emitted step and then cleared, so steps do not inherit
    /// earlier math. The word following "theorem"/"lemma" (any case) is recorded as a
    /// theorem name, both on that step and, deduplicated, on the whole proof.
    fn parse_proof_from_latex(&self, content: &str, trace: &AlgorithmTrace) -> Result<Proof> {
        let mut statements = Vec::new();
        // All theorem names seen, deduplicated, in first-seen order.
        let mut theorems: Vec<String> = Vec::new();
        // LaTeX gathered since the last emitted step.
        let mut pending_latex = String::new();
        let mut in_display_math = false;

        for raw_line in content.lines() {
            let line = raw_line.trim();

            if in_display_math {
                if line.contains("$$") || line.contains("\\]") {
                    in_display_math = false;
                } else {
                    Self::append_latex(&mut pending_latex, line);
                }
                continue;
            }

            if line.contains("$$") || line.contains("\\[") || line.contains("\\]") {
                let dollar_delims = line.matches("$$").count();
                let opens = line.contains("\\[");
                let closes = line.contains("\\]");
                if dollar_delims >= 2 || (opens && closes) {
                    // Self-contained display math such as `$$x^2$$` or `\[x^2\]`.
                    Self::append_latex(&mut pending_latex, line);
                } else if dollar_delims == 1 || opens {
                    in_display_math = true;
                }
                // A lone `\]` without an opener is a stray delimiter and is dropped.
                continue;
            }

            if line.is_empty() || line.starts_with('#') {
                continue;
            }

            for segment in Self::inline_math_segments(line) {
                Self::append_latex(&mut pending_latex, segment);
            }

            let line_theorems = Self::theorem_names(line);
            for name in &line_theorems {
                if !theorems.contains(name) {
                    theorems.push(name.clone());
                }
            }

            let latex = if pending_latex.is_empty() {
                line.to_string()
            } else {
                std::mem::take(&mut pending_latex)
            };

            statements.push(MathStatement::Step {
                statement: line.to_string(),
                latex,
                justification: "Generated by GLM-4.6".to_string(),
                theorems: line_theorems,
            });
        }

        // Must be read before `statements` is moved into the `Proof` below.
        let step_count = statements.len();
        Ok(Proof {
            id: uuid::Uuid::new_v4().to_string(),
            problem: trace.algorithm_name.clone(),
            method: "direct".to_string(),
            statements,
            answer: Some(trace.output.clone()),
            theorems,
            confidence: 0.8,
            metadata: ProofMetadata {
                generated_at: chrono::Utc::now().to_rfc3339(),
                generated_by: "glm-4.6".to_string(),
                step_count,
                variant: None,
                execution_time_ms: trace.execution_time_ms,
                token_usage: TokenUsage {
                    input_tokens: 0,
                    output_tokens: 0,
                    total_tokens: 0,
                },
            },
        })
    }

    /// Appends `fragment` to `buf`, separating fragments with a single space.
    fn append_latex(buf: &mut String, fragment: &str) {
        if fragment.is_empty() {
            return;
        }
        if !buf.is_empty() {
            buf.push(' ');
        }
        buf.push_str(fragment);
    }

    /// Returns every complete inline `$…$` span in `line`, delimiters included.
    fn inline_math_segments(line: &str) -> Vec<&str> {
        let mut segments = Vec::new();
        let mut rest = line;
        while let Some(start) = rest.find('$') {
            let after = &rest[start + 1..];
            let Some(end) = after.find('$') else { break };
            // `$` is ASCII, so these byte offsets are valid char boundaries.
            segments.push(&rest[start..start + end + 2]);
            rest = &after[end + 1..];
        }
        segments
    }

    /// Returns the word following each "theorem"/"lemma" mention (case-insensitive),
    /// stripped of trailing `.`, `,` and `:`; empty names are skipped.
    fn theorem_names(line: &str) -> Vec<String> {
        let words: Vec<&str> = line.split_whitespace().collect();
        words
            .windows(2)
            .filter(|pair| {
                let lower = pair[0].to_lowercase();
                lower.contains("theorem") || lower.contains("lemma")
            })
            .map(|pair| pair[1].trim_matches(&['.', ',', ':'][..]).to_string())
            .filter(|name| !name.is_empty())
            .collect()
    }
}
/// Textual record of a single algorithm run, used as the subject of a generated proof.
#[derive(Clone, Debug)]
pub struct AlgorithmTrace {
    /// Name of the algorithm that was executed.
    pub algorithm_name: String,
    /// The input the algorithm was run on, as text.
    pub input: String,
    /// Description of each step, in the order they were executed.
    pub execution_steps: Vec<String>,
    /// The result the algorithm produced, as text.
    pub output: String,
    /// Execution time in milliseconds.
    pub execution_time_ms: u64,
}
pub struct ExpansiveProofGenerator {
client: Arc<GLM46Client>,
proofs: Arc<RwLock<Vec<Proof>>>,
max_variants: usize,
}
impl ExpansiveProofGenerator {
pub fn new(client: GLM46Client) -> Self {
Self {
client: Arc::new(client),
proofs: Arc::new(RwLock::new(Vec::new())),
max_variants: 4, }
}
pub async fn generate_variants(&self, problem: &MathProblem) -> Result<Vec<Proof>> {
let proofs = self.proofs.clone();
let client = self.client.clone();
let problem = problem.clone();
let max_variants = self.max_variants;
let methods = vec![
"direct proof",
"proof by contradiction",
"proof by induction",
"proof by contraposition",
];
let mut results = Vec::new();
for method in methods.iter().take(max_variants) {
let variant = self.generate_proof_with_method(&problem, method).await?;
results.push(variant);
}
{
let mut proofs_guard = proofs.write().await;
proofs_guard.extend(results.clone());
}
Ok(results)
}
async fn generate_proof_with_method(
&self,
problem: &MathProblem,
method: &str,
) -> Result<Proof> {
let prompt = format!(
"Generate a {} proof for the following mathematical problem:\n\n\
Problem: {}\n\
Problem statement (LaTeX): {}\n\n\
Requirements:\n\
1. Use the {} method\n\
2. Include step-by-step reasoning\n\
3. Apply relevant theorems\n\
4. Use LaTeX for all mathematical expressions\n\
5. Output in structured JSON format",
method, problem.id, problem.latex, method
);
let response = self
.client
.chat_completion(ChatRequest {
messages: vec![
ChatMessage::system(GENERATION_SYSTEM_PROMPT),
ChatMessage::user(&prompt),
],
response_format: Some(ResponseFormat::Structured),
temperature: 0.15, max_tokens: 4000,
..Default::default()
})
.await?;
Ok(Proof {
id: uuid::Uuid::new_v4().to_string(),
problem: problem.id.clone(),
method: method.to_string(),
statements: vec![MathStatement::Step {
statement: response.content.clone(),
latex: response.content.clone(),
justification: format!("Generated using {} method", method),
theorems: vec![],
}],
answer: problem.answer.clone(),
theorems: vec![],
confidence: 0.85, metadata: ProofMetadata {
generated_at: chrono::Utc::now().to_rfc3339(),
generated_by: "glm-4.6".to_string(),
step_count: 1, variant: Some(method.to_string()),
execution_time_ms: 0, token_usage: TokenUsage {
input_tokens: 0,
output_tokens: 0,
total_tokens: 0,
},
},
})
}
pub async fn select_best(&self) -> Result<Proof> {
let proofs = self.proofs.read().await;
if proofs.is_empty() {
return Err(Error::Verification("No proofs generated".to_string()));
}
Ok(proofs
.iter()
.max_by(|a, b| a.confidence.partial_cmp(&b.confidence).unwrap())
.unwrap()
.clone())
}
pub async fn get_all_proofs(&self) -> Vec<Proof> {
let proofs = self.proofs.read().await;
proofs.clone()
}
}
/// A proof paired with the method name and a confidence score.
#[derive(Clone, Debug)]
pub struct ProofVariant {
    /// The proof itself.
    pub proof: Proof,
    /// Name of the proof method used.
    pub method: String,
    /// Confidence score associated with this variant.
    pub confidence: f64,
}
/// A mathematical problem to be proven in several ways.
#[derive(Clone, Debug)]
pub struct MathProblem {
    /// Identifier for the problem; copied into each generated proof's `problem` field.
    pub id: String,
    /// Plain-text statement of the problem.
    pub statement: String,
    /// LaTeX form of the problem statement.
    pub latex: String,
    /// Known answer, if any; copied into each generated proof.
    pub answer: Option<String>,
}
/// System prompt sent as the first message of every proof-generation request, both
/// from `ProofGenerator::generate_from_execution` and from
/// `ExpansiveProofGenerator`'s per-method requests.
///
/// It asks the model for a JSON object with `problem`, `method`, `statements`,
/// `answer` and `theorems_used` keys.
///
/// NOTE(review): `parse_generated_proof` first tries to deserialize the reply
/// directly as a `Proof`. Confirm that `Proof`'s serde layout accepts this schema.
/// For example, `theorems_used` here vs. a `theorems` field, and no `id`,
/// `confidence` or `metadata` keys are requested. If it does not, every reply falls
/// back to the line-based text parser.
const GENERATION_SYSTEM_PROMPT: &str = r#"You are an expert mathematician and theorem prover using GLM-4.6's strong mathematical reasoning capabilities (91.0% AIME performance).
Your task is to generate rigorous, step-by-step mathematical proofs with:
1. **Rigorous Logic**: Each step must logically follow from previous steps or known theorems
2. **Clear Justification**: Explain why each step is valid
3. **Theorem Application**: Explicitly reference relevant theorems, lemmas, or axioms
4. **LaTeX Formatting**: Use proper LaTeX notation for all mathematical expressions
5. **Structured Output**: Output in JSON format with clear structure
Always verify your reasoning at each step and ensure mathematical correctness.
Output Format (JSON):
{
"problem": "Problem statement",
"method": "proof method used",
"statements": [
{
"type": "step" | "axiom" | "conclusion",
"statement": "plain text statement",
"latex": "LaTeX representation",
"justification": "why this step is valid",
"theorems": ["theorem1", "theorem2"]
}
],
"answer": "final answer (if applicable)",
"theorems_used": ["theorem1", "theorem2"]
}"#;
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample trace describing the steps of solving a quadratic equation.
    fn quadratic_formula_trace() -> AlgorithmTrace {
        let steps = [
            "Identify coefficients a, b, c",
            "Calculate discriminant: b^2 - 4ac",
            "Apply quadratic formula",
        ];
        AlgorithmTrace {
            algorithm_name: String::from("quadratic_formula"),
            input: String::from("ax^2 + bx + c = 0"),
            execution_steps: steps.iter().map(|step| step.to_string()).collect(),
            output: String::from("x = (-b ± √(b^2-4ac)) / 2a"),
            execution_time_ms: 10,
        }
    }

    #[test]
    fn test_algorithm_trace_creation() {
        let trace = quadratic_formula_trace();
        assert_eq!(trace.algorithm_name, "quadratic_formula");
        assert_eq!(trace.execution_steps.len(), 3);
    }

    #[test]
    fn test_expansive_generator_creation() {
        // NOTE(review): both constructions unwrap `GLM46Client::from_env()`, so this test
        // panics if that call fails in the test environment. Confirm what it requires.
        // Building the struct literally checks the field layout still compiles.
        let _explicit = ExpansiveProofGenerator {
            client: Arc::new(GLM46Client::from_env().unwrap()),
            proofs: Arc::new(RwLock::new(Vec::new())),
            max_variants: 4,
        };
        let via_new = ExpansiveProofGenerator::new(GLM46Client::from_env().unwrap());
        assert_eq!(via_new.max_variants, 4);
    }
}