use std::sync::Arc;
use async_trait::async_trait;
use nab::AcceleratedClient;
use nab::analyze::active_reading::{
ActiveReadingError, LlmSampler, Reference, ReferenceKind, Result, UrlFetcher,
};
use rust_mcp_sdk::McpServer;
use serde::Deserialize;
use tracing::debug;
/// [`LlmSampler`] implementation that sends its sampling requests through an
/// MCP server handle.
pub struct McpLlmSampler {
/// Shared MCP server handle handed to `crate::sampling::create_message`
/// for every request.
runtime: Arc<dyn McpServer>,
}
impl McpLlmSampler {
pub fn new(runtime: Arc<dyn McpServer>) -> Self {
Self { runtime }
}
}
#[async_trait]
impl LlmSampler for McpLlmSampler {
    /// Asks the model which references in `chunk` are worth looking up.
    ///
    /// The prompt comes from `build_identify_prompt`, and the reply is decoded
    /// by `parse_references_response`, which stamps every reference with
    /// `segment_offset`.
    ///
    /// # Errors
    ///
    /// Returns [`ActiveReadingError::SamplingFailed`] when the sampling call
    /// fails, or whatever error `parse_references_response` produces for a
    /// reply it cannot decode.
    async fn identify_references(
        &self,
        chunk: &str,
        segment_offset: usize,
    ) -> Result<Vec<Reference>> {
        debug!(
            segment_offset,
            chunk_len = chunk.len(),
            "active reading: sampling identify_references"
        );
        let prompt = build_identify_prompt(chunk);
        // 500 sits in the same argument position as `max_tokens` in `summarize`.
        match crate::sampling::create_message(&self.runtime, &prompt, 500, None).await {
            Ok(response_text) => parse_references_response(&response_text, segment_offset),
            Err(e) => Err(ActiveReadingError::SamplingFailed(e.to_string())),
        }
    }

    /// Requests a query-focused summary of `content`.
    ///
    /// `content` is first shortened with `trim_content`; `max_tokens` is
    /// forwarded unchanged to the sampling call.
    ///
    /// # Errors
    ///
    /// Returns [`ActiveReadingError::SamplingFailed`] when the sampling call
    /// fails.
    async fn summarize(&self, content: &str, query: &str, max_tokens: u32) -> Result<String> {
        let trimmed = trim_content(content);
        let prompt = format!(
            "Summarize this content in ~150 words, focusing on what answers the query.\n\
             Query: {query}\n\nContent:\n{trimmed}"
        );
        let outcome =
            crate::sampling::create_message(&self.runtime, &prompt, max_tokens, None).await;
        outcome.map_err(|err| ActiveReadingError::SamplingFailed(err.to_string()))
    }
}
/// [`UrlFetcher`] implementation that delegates to a shared
/// [`AcceleratedClient`].
pub struct NabUrlFetcher {
/// Shared client whose `fetch_text` method performs the actual fetch.
client: Arc<AcceleratedClient>,
}
impl NabUrlFetcher {
pub fn new(client: Arc<AcceleratedClient>) -> Self {
Self { client }
}
}
#[async_trait]
impl UrlFetcher for NabUrlFetcher {
    /// Fetches `url` as text through the wrapped client.
    ///
    /// # Errors
    ///
    /// Any client failure is reported as [`ActiveReadingError::FetchFailed`]
    /// carrying the underlying error's display text.
    async fn fetch_text(&self, url: &str) -> Result<String> {
        let fetched = self.client.fetch_text(url).await;
        fetched.map_err(|err| ActiveReadingError::FetchFailed(err.to_string()))
    }
}
/// Builds the prompt that asks the model to list lookup-worthy references in
/// a transcript chunk.
///
/// The prompt is a fixed instruction preamble (including the expected JSON
/// shape) followed by `chunk` verbatim.
fn build_identify_prompt(chunk: &str) -> String {
    // Fixed instructions; the transcript chunk is appended directly after.
    const PREAMBLE: &str =
        "You are analyzing a video transcript chunk. Identify references that warrant lookup.\n\
         Return ONLY valid JSON in this exact format:\n\
         {\"refs\": [{\"kind\": \"paper|person|tool|claim|number|other\", \
         \"query\": \"...\", \"confidence\": 0.0-1.0}]}\n\
         Be conservative — only flag concrete, lookupable items. No more than 5 per chunk.\n\
         Transcript chunk:\n";

    let mut prompt = String::with_capacity(PREAMBLE.len() + chunk.len());
    prompt.push_str(PREAMBLE);
    prompt.push_str(chunk);
    prompt
}
/// Returns at most the first 4,000 bytes of `content`.
///
/// Input that already fits is returned unchanged. Longer input is cut at the
/// last UTF-8 character boundary at or below the limit, so the result is
/// always a valid `&str` and never splits a character.
fn trim_content(content: &str) -> &str {
    // The limit is measured in bytes (`str::len`), not in characters.
    const MAX_BYTES: usize = 4_000;

    if content.len() <= MAX_BYTES {
        return content;
    }

    // Search downward from the limit for the nearest character boundary.
    // Index 0 is always a boundary, so the fallback is never actually used.
    let cut = (0..=MAX_BYTES)
        .rev()
        .find(|&idx| content.is_char_boundary(idx))
        .unwrap_or(0);
    &content[..cut]
}
/// Top-level JSON object expected in the model's reply to the
/// identify-references prompt: `{"refs": [...]}`.
#[derive(Deserialize)]
struct ParseReferencesResponse {
/// Raw reference entries, converted to `Reference` values by
/// `parse_references_response`.
refs: Vec<ParseReferencesRawRef>,
}
/// One entry of the `refs` array, exactly as deserialized from the model's
/// JSON reply.
#[derive(Deserialize)]
struct ParseReferencesRawRef {
/// Kind label as text; mapped to a `ReferenceKind` by `kind_from_str`.
kind: String,
/// Lookup query text, copied into `Reference::query`.
query: String,
/// Confidence value, copied unchanged into `Reference::confidence`.
confidence: f32,
}
/// Decodes the model's identify-references reply into `Reference` values.
///
/// `text` is passed through `strip_code_fences` and then parsed as a
/// `{"refs": [...]}` JSON object. Each entry's kind label is mapped with
/// `kind_from_str`, and every resulting reference gets `segment_offset` as
/// its `segment_idx`.
///
/// # Errors
///
/// Returns [`ActiveReadingError::InvalidResponse`] (message prefixed with
/// `JSON parse: `) when the payload is not valid JSON of the expected shape.
pub(crate) fn parse_references_response(
    text: &str,
    segment_offset: usize,
) -> Result<Vec<Reference>> {
    let payload = strip_code_fences(text);
    let envelope = match serde_json::from_str::<ParseReferencesResponse>(payload) {
        Ok(envelope) => envelope,
        Err(e) => {
            return Err(ActiveReadingError::InvalidResponse(format!(
                "JSON parse: {e}"
            )));
        }
    };

    let mut references = Vec::with_capacity(envelope.refs.len());
    for raw in envelope.refs {
        references.push(Reference {
            kind: kind_from_str(&raw.kind),
            query: raw.query,
            confidence: raw.confidence,
            segment_idx: segment_offset,
        });
    }
    Ok(references)
}
/// Removes a surrounding Markdown code fence from `text` and returns the
/// trimmed inner payload.
///
/// Handles, independently of each other:
/// - an opening fence (three backticks), optionally followed by an info
///   string (language tag) such as `json`, `JSON` or `javascript`;
/// - a closing fence at the end of the text.
///
/// Text without fences is returned trimmed. The result always borrows from
/// `text`; nothing is allocated.
fn strip_code_fences(text: &str) -> &str {
    let s = text.trim();
    let s = match s.strip_prefix("```") {
        Some(after_fence) => skip_fence_info(after_fence),
        None => s,
    };
    let s = s.strip_suffix("```").unwrap_or(s);
    s.trim()
}

/// Skips the info string (language tag) that may directly follow an opening
/// code fence; `after_fence` is the text right after the three backticks.
fn skip_fence_info(after_fence: &str) -> &str {
    // Usual layout: the tag sits alone on the fence line, e.g. "json\n{...}".
    // Any identifier-like word is accepted so "JSON", "jsonc" or "javascript"
    // do not leak into the payload. A payload line such as "{" never
    // qualifies because it does not start with a letter.
    if let Some((first_line, rest)) = after_fence.split_once('\n') {
        let tag = first_line.trim();
        let starts_with_letter = tag.starts_with(|c: char| c.is_ascii_alphabetic());
        let is_tag = starts_with_letter
            && tag
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '+' | '#'));
        if is_tag {
            return rest;
        }
    }

    // Tag glued to the payload on the same line ("json{...}"): only the
    // `json` tag can be told apart from payload text here. `get(..4)` is
    // `None` when the text is shorter than four bytes or byte 4 is not a
    // character boundary, so the slice below cannot panic.
    match after_fence.get(..4) {
        Some(tag) if tag.eq_ignore_ascii_case("json") => &after_fence[4..],
        _ => after_fence,
    }
}
/// Maps a kind label from the model's reply to a [`ReferenceKind`].
///
/// Matching ignores ASCII case. Any label other than `paper`, `person`,
/// `tool`, `claim` or `number` becomes [`ReferenceKind::Other`].
fn kind_from_str(s: &str) -> ReferenceKind {
    if s.eq_ignore_ascii_case("paper") {
        ReferenceKind::Paper
    } else if s.eq_ignore_ascii_case("person") {
        ReferenceKind::Person
    } else if s.eq_ignore_ascii_case("tool") {
        ReferenceKind::Tool
    } else if s.eq_ignore_ascii_case("claim") {
        ReferenceKind::Claim
    } else if s.eq_ignore_ascii_case("number") {
        ReferenceKind::Number
    } else {
        ReferenceKind::Other
    }
}
/// Unit tests for the reply parser and the small text helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Unfenced JSON is parsed and every field is carried over.
    #[test]
    fn parse_references_response_handles_bare_json() {
        let text = r#"{"refs": [{"kind": "paper", "query": "Dijkstra 1968", "confidence": 0.95}]}"#;
        let refs = parse_references_response(text, 3).unwrap();
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].kind, ReferenceKind::Paper);
        assert_eq!(refs[0].query, "Dijkstra 1968");
        assert!((refs[0].confidence - 0.95).abs() < 0.01);
        assert_eq!(refs[0].segment_idx, 3);
    }

    /// A fence tagged `json` is removed before parsing.
    #[test]
    fn parse_references_response_handles_markdown_fences() {
        let text = "```json\n{\"refs\": [{\"kind\": \"person\", \"query\": \"Geoffrey Hinton\", \"confidence\": 0.9}]}\n```";
        let refs = parse_references_response(text, 0).unwrap();
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].kind, ReferenceKind::Person);
        assert_eq!(refs[0].query, "Geoffrey Hinton");
    }

    /// An untagged fence is removed before parsing.
    #[test]
    fn parse_references_response_handles_plain_fences() {
        let text = "```\n{\"refs\": [{\"kind\": \"tool\", \"query\": \"ripgrep\", \"confidence\": 0.85}]}\n```";
        let refs = parse_references_response(text, 1).unwrap();
        assert_eq!(refs.len(), 1);
        assert_eq!(refs[0].kind, ReferenceKind::Tool);
    }

    /// Non-JSON text is reported as `InvalidResponse` rather than panicking.
    #[test]
    fn parse_references_response_errors_on_malformed_json() {
        let text = "this is not JSON at all";
        let result = parse_references_response(text, 0);
        assert!(matches!(
            result,
            Err(ActiveReadingError::InvalidResponse(_))
        ));
    }

    /// An empty `refs` array yields an empty vector.
    #[test]
    fn parse_references_response_handles_empty_refs() {
        let text = r#"{"refs": []}"#;
        let refs = parse_references_response(text, 0).unwrap();
        assert!(refs.is_empty());
    }

    /// Unrecognised kind labels fall back to `Other`.
    #[test]
    fn kind_from_str_unknown_kind_becomes_other() {
        let kind = kind_from_str("widget");
        assert_eq!(kind, ReferenceKind::Other);
    }

    /// Kind labels are matched without regard to ASCII case.
    #[test]
    fn kind_from_str_is_ascii_case_insensitive() {
        assert_eq!(kind_from_str("PAPER"), ReferenceKind::Paper);
        assert_eq!(kind_from_str("Person"), ReferenceKind::Person);
    }

    /// Input below the limit is returned as-is.
    #[test]
    fn trim_content_short_is_unchanged() {
        let content = "Hello, world!";
        let result = trim_content(content);
        assert_eq!(result, content);
    }

    /// ASCII input above the limit is cut to exactly the limit.
    #[test]
    fn trim_content_long_is_truncated() {
        let content = "a".repeat(8_000);
        let result = trim_content(&content);
        assert!(result.len() <= 4_000);
        assert_eq!(result.len(), 4_000);
        assert!(std::str::from_utf8(result.as_bytes()).is_ok());
    }

    /// Multi-byte input is cut at a character boundary, not mid-character.
    #[test]
    fn trim_content_backs_off_to_char_boundary() {
        // '€' is 3 bytes in UTF-8, so 2_000 of them span 6_000 bytes and byte
        // 4_000 falls inside a character; the cut must retreat to 3_999.
        let content = "€".repeat(2_000);
        let result = trim_content(&content);
        assert_eq!(result.len(), 3_999);
        assert!(content.starts_with(result));
        assert!(result.chars().all(|c| c == '€'));
    }
}