use std::sync::Arc;
use async_trait::async_trait;
use serde_json::Value;
use crate::llm::{BaseLlm, LlmError, LlmRequest};
use super::transcript::TranscriptTurn;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExtractionTrigger {
EveryTurn,
Interval(u32),
AfterToolCall,
OnPhaseChange,
OnGenerationComplete,
}
fn strip_code_fences(text: &str) -> &str {
let trimmed = text.trim();
if let Some(rest) = trimmed.strip_prefix("```") {
let rest = rest.trim_start_matches(|c: char| c != '\n');
let rest = rest.strip_prefix('\n').unwrap_or(rest);
let rest = rest.trim_end();
rest.strip_suffix("```").unwrap_or(rest).trim()
} else {
trimmed
}
}
#[async_trait]
pub trait TurnExtractor: Send + Sync {
fn name(&self) -> &str;
fn window_size(&self) -> usize;
fn should_extract(&self, window: &[TranscriptTurn]) -> bool {
let _ = window;
true
}
fn trigger(&self) -> ExtractionTrigger {
ExtractionTrigger::EveryTurn
}
async fn extract(&self, window: &[TranscriptTurn]) -> Result<Value, LlmError>;
}
pub struct LlmExtractor {
name: String,
llm: Arc<dyn BaseLlm>,
prompt: String,
window_size: usize,
schema: Option<Value>,
schema_str: Option<String>,
min_words: usize,
trigger: ExtractionTrigger,
}
impl LlmExtractor {
pub fn new(
name: impl Into<String>,
llm: Arc<dyn BaseLlm>,
prompt: impl Into<String>,
window_size: usize,
) -> Self {
Self {
name: name.into(),
llm,
prompt: prompt.into(),
window_size,
schema: None,
schema_str: None,
min_words: 0,
trigger: ExtractionTrigger::EveryTurn,
}
}
pub fn with_min_words(mut self, n: usize) -> Self {
self.min_words = n;
self
}
pub fn with_schema(mut self, schema: Value) -> Self {
self.schema_str = serde_json::to_string_pretty(&schema).ok();
self.schema = Some(schema);
self
}
pub fn with_trigger(mut self, trigger: ExtractionTrigger) -> Self {
self.trigger = trigger;
self
}
fn format_transcript(window: &[TranscriptTurn]) -> String {
let mut out = String::new();
for turn in window {
if !turn.user.is_empty() {
out.push_str("User: ");
out.push_str(turn.user.trim());
out.push('\n');
}
if !turn.model.is_empty() {
out.push_str("Assistant: ");
out.push_str(turn.model.trim());
out.push('\n');
}
out.push('\n');
}
out
}
}
#[async_trait]
impl TurnExtractor for LlmExtractor {
fn name(&self) -> &str {
&self.name
}
fn window_size(&self) -> usize {
self.window_size
}
fn should_extract(&self, window: &[TranscriptTurn]) -> bool {
if self.min_words == 0 {
return true;
}
window
.iter()
.rev()
.find(|t| !t.user.is_empty())
.is_some_and(|t| t.user.split_whitespace().count() >= self.min_words)
}
fn trigger(&self) -> ExtractionTrigger {
self.trigger.clone()
}
async fn extract(&self, window: &[TranscriptTurn]) -> Result<Value, LlmError> {
let transcript = Self::format_transcript(window);
let mut request = LlmRequest::from_text(format!(
"Transcript:\n{transcript}\nExtract the requested information."
));
request.system_instruction = Some(self.prompt.clone());
if let Some(ref schema) = self.schema {
request.response_mime_type = Some("application/json".to_string());
request.response_json_schema = Some(schema.clone());
} else {
request.response_mime_type = Some("application/json".to_string());
}
let response = self.llm.generate(request).await?;
let text = response.text();
let cleaned = strip_code_fences(&text);
serde_json::from_str(cleaned).map_err(|e| {
LlmError::Other(format!(
"Failed to parse extraction result as JSON: {e}. Raw: {text}"
))
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::LlmResponse;
use rs_genai::prelude::{Content, Part, Role};
use std::time::Instant;
struct MockLlm {
response: String,
}
#[async_trait]
impl BaseLlm for MockLlm {
fn model_id(&self) -> &str {
"mock"
}
async fn generate(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
Ok(LlmResponse {
content: Content {
role: Some(Role::Model),
parts: vec![Part::Text {
text: self.response.clone(),
}],
},
finish_reason: Some("STOP".into()),
usage: None,
})
}
}
fn make_turns(pairs: &[(&str, &str)]) -> Vec<TranscriptTurn> {
pairs
.iter()
.enumerate()
.map(|(i, (user, model))| TranscriptTurn {
turn_number: i as u32,
user: user.to_string(),
model: model.to_string(),
tool_calls: Vec::new(),
timestamp: Instant::now(),
})
.collect()
}
#[tokio::test]
async fn llm_extractor_produces_json() {
let llm = Arc::new(MockLlm {
response: r#"{"phase": "ordering", "items": ["pizza"]}"#.to_string(),
});
let extractor = LlmExtractor::new("OrderState", llm, "Extract order state", 3);
let turns = make_turns(&[
("I'd like a pizza", "Great! What size?"),
("Large please", "Coming right up!"),
]);
let result = extractor.extract(&turns).await.unwrap();
assert_eq!(result["phase"], "ordering");
assert_eq!(result["items"][0], "pizza");
}
#[tokio::test]
async fn llm_extractor_with_schema() {
let llm = Arc::new(MockLlm {
response: r#"{"sentiment": "positive", "score": 0.9}"#.to_string(),
});
let schema = serde_json::json!({
"type": "object",
"properties": {
"sentiment": {"type": "string", "enum": ["positive", "neutral", "negative"]},
"score": {"type": "number"}
}
});
let extractor =
LlmExtractor::new("Sentiment", llm, "Rate sentiment", 1).with_schema(schema);
let turns = make_turns(&[("This is great!", "Glad you think so!")]);
let result = extractor.extract(&turns).await.unwrap();
assert_eq!(result["sentiment"], "positive");
}
#[tokio::test]
async fn llm_extractor_invalid_json_returns_error() {
let llm = Arc::new(MockLlm {
response: "not json at all".to_string(),
});
let extractor = LlmExtractor::new("Bad", llm, "Extract", 1);
let turns = make_turns(&[("hi", "hello")]);
let result = extractor.extract(&turns).await;
assert!(result.is_err());
}
#[test]
fn format_transcript_readable() {
let turns = make_turns(&[("Hello", "Hi there!"), ("How are you?", "I'm doing well")]);
let formatted = LlmExtractor::format_transcript(&turns);
assert!(formatted.contains("User: Hello"));
assert!(formatted.contains("Assistant: Hi there!"));
assert!(formatted.contains("User: How are you?"));
}
#[tokio::test]
async fn llm_extractor_handles_markdown_fenced_json() {
let llm = Arc::new(MockLlm {
response: "```json\n{\"status\": \"ok\"}\n```".to_string(),
});
let extractor = LlmExtractor::new("Fenced", llm, "Extract", 1);
let turns = make_turns(&[("test", "reply")]);
let result = extractor.extract(&turns).await.unwrap();
assert_eq!(result["status"], "ok");
}
#[test]
fn strip_code_fences_variants() {
assert_eq!(super::strip_code_fences("```json\n{}\n```"), "{}");
assert_eq!(super::strip_code_fences("```\n{}\n```"), "{}");
assert_eq!(
super::strip_code_fences(" ```json\n{\"a\":1}\n``` "),
"{\"a\":1}"
);
assert_eq!(
super::strip_code_fences("{\"bare\":true}"),
"{\"bare\":true}"
);
}
#[test]
fn extractor_name_and_window_size() {
let llm = Arc::new(MockLlm {
response: "{}".to_string(),
});
let ext = LlmExtractor::new("TestExtractor", llm, "test", 5);
assert_eq!(ext.name(), "TestExtractor");
assert_eq!(ext.window_size(), 5);
}
#[test]
fn extractor_default_trigger_is_every_turn() {
let llm = Arc::new(MockLlm {
response: "{}".to_string(),
});
let ext = LlmExtractor::new("Test", llm, "test", 5);
assert_eq!(ext.trigger(), ExtractionTrigger::EveryTurn);
}
#[test]
fn extractor_with_trigger() {
let llm = Arc::new(MockLlm {
response: "{}".to_string(),
});
let ext = LlmExtractor::new("Test", llm, "test", 5)
.with_trigger(ExtractionTrigger::AfterToolCall);
assert_eq!(ext.trigger(), ExtractionTrigger::AfterToolCall);
}
}