use std::ops::Range;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
use crate::extract::{EntitySpan, Extractor, RelationSpan};
use crate::types::Section;
/// Default base URL for the Ollama HTTP server (loopback address, port 11434).
pub const DEFAULT_OLLAMA_URL: &str = "http://127.0.0.1:11434";
/// Default model name to send in the `model` field of a request.
pub const DEFAULT_OLLAMA_MODEL: &str = "llama3.2:3b";
/// Confidence stamped on every entity span that passes `verify_entity`.
pub const LLM_ENTITY_CONFIDENCE: f32 = 0.75;
/// Confidence stamped on every relation emitted by `extract_relations`.
pub const LLM_RELATION_CONFIDENCE: f32 = 0.65;
/// Timeout configured on the blocking HTTP client built in `OllamaExtractor::new`.
const OLLAMA_TIMEOUT: Duration = Duration::from_secs(30);
/// Extractor that obtains entities and relations by posting section text to
/// an Ollama server's `/api/generate` endpoint.
pub struct OllamaExtractor {
// Blocking HTTP client used for every request; built with `OLLAMA_TIMEOUT`.
client: reqwest::blocking::Client,
// Server base URL, stored with trailing `/` characters removed.
base_url: String,
// Model name sent in the `model` field of each request body.
model: String,
}
/// Hand-written `Debug` that prints only `base_url` and `model`.
///
/// The `client` field is not printed; the trailing `..` produced by
/// `finish_non_exhaustive` signals that fields were omitted.
impl std::fmt::Debug for OllamaExtractor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut printer = f.debug_struct("OllamaExtractor");
        printer.field("base_url", &self.base_url);
        printer.field("model", &self.model);
        printer.finish_non_exhaustive()
    }
}
impl OllamaExtractor {
pub fn new(
base_url: impl Into<String>,
model: impl Into<String>,
) -> Result<Self, crate::Error> {
let client = reqwest::blocking::Client::builder()
.timeout(OLLAMA_TIMEOUT)
.build()
.map_err(|e| crate::Error::Extractor(format!("ollama client init: {e}")))?;
Ok(Self {
client,
base_url: base_url.into().trim_end_matches('/').to_string(),
model: model.into(),
})
}
fn invoke(&self, section_text: &str) -> LlmPayload {
let schema = serde_json::json!({
"type": "object",
"properties": {
"entities": {
"type": "array",
"items": {
"type": "object",
"properties": {
"kind": { "type": "string" },
"text": { "type": "string" },
"start": { "type": "integer" },
"end": { "type": "integer" }
},
"required": ["kind", "text", "start", "end"]
}
},
"relations": {
"type": "array",
"items": {
"type": "object",
"properties": {
"kind": { "type": "string" },
"subj": { "type": "integer" },
"obj": { "type": "integer" }
},
"required": ["kind", "subj", "obj"]
}
}
},
"required": ["entities", "relations"]
});
let prompt = format!(
"Extract entities and relations from the following text. \
Return STRICTLY JSON matching the schema. \
Entity 'start' and 'end' are byte offsets into the text. \
Entity 'kind' is a descriptive label for the entity type (any label is valid, \
e.g. person, organization, location, product, event, chemical, concept). \
Relation 'kind' is one of: co_occurs_with, acts_on. \
Relation 'subj' and 'obj' are indices into the entities array.\n\n\
TEXT:\n{section_text}"
);
let body = serde_json::json!({
"model": self.model,
"prompt": prompt,
"format": schema,
"stream": false,
});
let url = format!("{}/api/generate", self.base_url);
let resp = match self.client.post(&url).json(&body).send() {
Ok(r) => r,
Err(e) => {
warn!(error = %e, url = %url, "ollama request failed; falling back to empty extract");
return LlmPayload::default();
}
};
if !resp.status().is_success() {
warn!(status = %resp.status(), "ollama returned non-2xx; fallback");
return LlmPayload::default();
}
let envelope: OllamaEnvelope = match resp.json() {
Ok(v) => v,
Err(e) => {
warn!(error = %e, "ollama envelope not JSON; fallback");
return LlmPayload::default();
}
};
match serde_json::from_str::<LlmPayload>(&envelope.response) {
Ok(p) => {
debug!(
entities = p.entities.len(),
relations = p.relations.len(),
"ollama payload parsed"
);
p
}
Err(e) => {
warn!(error = %e, "ollama payload schema rejected; fallback");
LlmPayload::default()
}
}
}
}
/// [`Extractor`] implementation backed by an Ollama model.
///
/// Each method issues its own request via `invoke`; nothing is cached between
/// `extract_entities` and `extract_relations`.
impl Extractor for OllamaExtractor {
    /// Returns the model-proposed entities for `section` that pass
    /// `verify_entity`. Entities whose offsets or text do not match the
    /// section are silently dropped; any request failure yields an empty list.
    fn extract_entities(&self, section: &Section) -> Vec<EntitySpan> {
        let payload = self.invoke(&section.text);
        payload
            .entities
            .into_iter()
            .filter_map(|e| verify_entity(e, &section.text))
            .collect()
    }

    /// Returns the model-proposed relations for `section`, expressed as
    /// indices into `entities`.
    ///
    /// The model's `subj`/`obj` values index the *raw* entities array of this
    /// response, which may contain entries that `verify_entity` rejects and
    /// may differ from the response that produced `entities`. Each raw index
    /// is therefore resolved by verifying the raw entity and locating an
    /// element of `entities` with the same byte range and kind. A relation is
    /// dropped when either endpoint cannot be resolved, so a relation is never
    /// attributed to a different entity than the one the model referenced.
    fn extract_relations(&self, entities: &[EntitySpan], section: &Section) -> Vec<RelationSpan> {
        let payload = self.invoke(&section.text);

        // positions[i] is the index in `entities` that raw entity `i` maps to,
        // or `None` if it failed verification or is not present in `entities`.
        let positions: Vec<Option<usize>> = payload
            .entities
            .into_iter()
            .map(|raw| {
                let verified = verify_entity(raw, &section.text)?;
                entities.iter().position(|known| {
                    known.byte_range == verified.byte_range && known.kind == verified.kind
                })
            })
            .collect();

        let resolve = |raw_index: u32| -> Option<usize> {
            let slot = usize::try_from(raw_index).ok()?;
            positions.get(slot).copied().flatten()
        };

        payload
            .relations
            .into_iter()
            .filter_map(|r| {
                let subject_span = resolve(r.subj)?;
                let object_span = resolve(r.obj)?;
                Some(RelationSpan {
                    kind: r.kind,
                    subject_span,
                    object_span,
                    confidence: LLM_RELATION_CONFIDENCE,
                })
            })
            .collect()
    }
}
/// Model output parsed from the envelope's `response` string: the proposed
/// entities and relations. Either array may be absent in the JSON, in which
/// case it deserializes as empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct LlmPayload {
// Missing `entities` key deserializes to an empty vector.
#[serde(default)]
entities: Vec<LlmEntity>,
// Missing `relations` key deserializes to an empty vector.
#[serde(default)]
relations: Vec<LlmRelation>,
}
/// One entity as proposed by the model, before verification against the
/// section text (see `verify_entity`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LlmEntity {
// Free-form type label chosen by the model.
kind: String,
// Surface text the model claims to have found.
text: String,
// Claimed byte offset where the entity starts in the section text.
start: usize,
// Claimed byte offset where the entity ends (used as an exclusive bound).
end: usize,
}
/// One relation as proposed by the model. The prompt asks for `subj` and
/// `obj` to be indices into the payload's `entities` array.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LlmRelation {
// Relation label; the prompt requests `co_occurs_with` or `acts_on`.
kind: String,
// Index of the subject entity.
subj: u32,
// Index of the object entity.
obj: u32,
}
/// The part of the `/api/generate` HTTP response body this module reads.
/// Only `response` is declared here.
#[derive(Debug, Clone, Deserialize)]
struct OllamaEnvelope {
// Model output as a string; `invoke` parses it as JSON into `LlmPayload`.
response: String,
}
/// Checks a model-proposed entity against the text it was extracted from and
/// converts it into an [`EntitySpan`].
///
/// Returns `None` (rejecting the entity) when any of the following holds:
/// - `start..end` is inverted, out of bounds, or does not fall on UTF-8
///   character boundaries of `section_text`;
/// - the text at `start..end` is not exactly `e.text`;
/// - `e.kind` is empty after trimming whitespace.
///
/// Accepted spans carry the trimmed kind, the original text and offsets, and
/// `LLM_ENTITY_CONFIDENCE`.
fn verify_entity(e: LlmEntity, section_text: &str) -> Option<EntitySpan> {
    // `str::get` performs the ordering, bounds and char-boundary checks in one
    // step and cannot panic, unlike direct slicing.
    let slice = section_text.get(e.start..e.end)?;
    if slice != e.text {
        return None;
    }

    let kind = e.kind.trim();
    if kind.is_empty() {
        return None;
    }

    Some(EntitySpan {
        kind: kind.to_string(),
        text: e.text,
        byte_range: Range {
            start: e.start,
            end: e.end,
        },
        confidence: LLM_ENTITY_CONFIDENCE,
    })
}
// Tests for `OllamaExtractor` that run against a local mock HTTP server
// standing in for Ollama's `/api/generate` endpoint.
#[cfg(test)]
mod tests {
use super::*;
use httpmock::prelude::*;
// Builds a `Section` with no heading, depth 0, `body` as its text, and a
// byte range covering the whole of `body`.
fn make_section(body: &str) -> Section {
Section {
heading: None,
depth: 0,
text: body.to_string(),
byte_range: 0..body.len(),
}
}
// Wraps `payload` in a response envelope: the payload is serialized to a
// string and placed in the `response` field, which is what `invoke` parses.
fn ok_response(payload: &serde_json::Value) -> serde_json::Value {
serde_json::json!({
"model": "llama3.2:3b",
"created_at": "2026-04-24T00:00:00Z",
"response": payload.to_string(),
"done": true
})
}
// Entities whose offsets and text match the section are returned with the
// LLM entity confidence, and a relation between them is returned too.
#[test]
fn valid_schema_round_trips_verified_spans() {
let server = MockServer::start();
let payload = serde_json::json!({
"entities": [
{ "kind": "person", "text": "Alice", "start": 0, "end": 5 },
{ "kind": "person", "text": "Bob", "start": 10, "end": 13 }
],
"relations": [
{ "kind": "co_occurs_with", "subj": 0, "obj": 1 }
]
});
let _mock = server.mock(|when, then| {
when.method(POST).path("/api/generate");
then.status(200)
.header("content-type", "application/json")
.json_body(ok_response(&payload));
});
let ex = OllamaExtractor::new(server.base_url(), "t").unwrap();
let sec = make_section("Alice met Bob.");
let ents = ex.extract_entities(&sec);
assert_eq!(ents.len(), 2);
assert_eq!(ents[0].text, "Alice");
assert!((ents[0].confidence - LLM_ENTITY_CONFIDENCE).abs() < f32::EPSILON);
let rels = ex.extract_relations(&ents, &sec);
assert_eq!(rels.len(), 1);
assert_eq!(rels[0].kind, "co_occurs_with");
}
// An entity whose claimed text ("Charlie") is not what the section contains
// at the claimed offsets must be dropped.
#[test]
fn hallucinated_span_is_rejected() {
let server = MockServer::start();
let payload = serde_json::json!({
"entities": [
{ "kind": "person", "text": "Charlie", "start": 0, "end": 7 }
],
"relations": []
});
let _mock = server.mock(|when, then| {
when.method(POST).path("/api/generate");
then.status(200).json_body(ok_response(&payload));
});
let ex = OllamaExtractor::new(server.base_url(), "t").unwrap();
let sec = make_section("Alice met Bob.");
assert!(ex.extract_entities(&sec).is_empty());
}
// A non-2xx status from the server yields an empty extraction.
#[test]
fn http_500_falls_back_to_empty() {
let server = MockServer::start();
let _mock = server.mock(|when, then| {
when.method(POST).path("/api/generate");
then.status(500);
});
let ex = OllamaExtractor::new(server.base_url(), "t").unwrap();
let sec = make_section("whatever");
assert!(ex.extract_entities(&sec).is_empty());
}
// A well-formed envelope whose `response` string is not parseable as the
// payload JSON yields an empty extraction.
#[test]
fn schema_invalid_response_is_rejected() {
let server = MockServer::start();
let _mock = server.mock(|when, then| {
when.method(POST).path("/api/generate");
then.status(200).json_body(serde_json::json!({
"model": "t",
"response": "not-json-at-all",
"done": true
}));
});
let ex = OllamaExtractor::new(server.base_url(), "t").unwrap();
let sec = make_section("Alice met Bob.");
assert!(ex.extract_entities(&sec).is_empty());
}
}