use std::collections::HashMap;
use std::time::Duration;
use serde::Deserialize;
use tracing::{debug, warn};
use crate::candidate::LlmCandidate;
/// Overall request timeout, in seconds, of the HTTP client used for generate requests.
const GENERATE_TIMEOUT_SECS: u64 = 120;
/// Overall request timeout, in seconds, of the HTTP client used for the health check.
const HEALTH_TIMEOUT_SECS: u64 = 3;
/// Connect timeout, in seconds, applied to every client created by `build_client`.
const CONNECT_TIMEOUT_SECS: u64 = 5;
/// Connection and request settings for talking to an Ollama server over HTTP.
///
/// The adapter stores configuration only; its methods build a new blocking HTTP
/// client for each request they make.
#[derive(Clone, Debug)]
pub struct OllamaAdapter {
/// Base URL of the server; the `/api/tags` and `/api/generate` paths are appended to it.
pub base_url: String,
/// Model name sent as the `model` field of generate requests.
pub model: String,
/// Value sent as `options.temperature` in generate requests.
pub temperature: f32,
/// Value sent as `options.num_ctx` in generate requests.
pub num_ctx: u32,
}
impl Default for OllamaAdapter {
    /// Returns an adapter aimed at `http://localhost:11434` that uses the
    /// `qwen2.5-coder:14b` model, a temperature of `0.1` and a `num_ctx` of `8192`.
    fn default() -> Self {
        Self::from_parts("http://localhost:11434", "qwen2.5-coder:14b", 0.1, 8192)
    }
}
/// Read-only accessors for the values needed to build an [`OllamaAdapter`].
///
/// Any type implementing this trait can be handed to [`OllamaAdapter::from_config`],
/// which copies each value into the adapter.
pub trait OllamaSettings {
/// Base URL of the Ollama server.
fn base_url(&self) -> &str;
/// Name of the model to request.
fn model(&self) -> &str;
/// Temperature to send with generate requests.
fn temperature(&self) -> f32;
/// `num_ctx` option to send with generate requests.
fn num_ctx(&self) -> u32;
}
impl OllamaAdapter {
#[must_use]
pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
Self {
base_url: base_url.into(),
model: model.into(),
temperature: 0.1,
num_ctx: 8192,
}
}
/// Builds an adapter by copying every value exposed by `settings`.
#[must_use]
pub fn from_config<S: OllamaSettings>(settings: &S) -> Self {
    let base_url = settings.base_url().to_owned();
    let model = settings.model().to_owned();
    let temperature = settings.temperature();
    let num_ctx = settings.num_ctx();
    Self {
        base_url,
        model,
        temperature,
        num_ctx,
    }
}
/// Assembles an adapter from explicitly supplied values, converting the two
/// string-like arguments into owned `String`s.
#[must_use]
pub fn from_parts(
    base_url: impl Into<String>,
    model: impl Into<String>,
    temperature: f32,
    num_ctx: u32,
) -> Self {
    let base_url: String = base_url.into();
    let model: String = model.into();
    Self {
        base_url,
        model,
        temperature,
        num_ctx,
    }
}
/// Probes the server with `GET {base_url}/api/tags` and reports whether it answered
/// with a success status.
///
/// Trailing slashes on `base_url` are stripped before the path is appended, so a
/// configured URL such as `http://localhost:11434/` does not turn into a
/// `//api/tags` request. Every failure (the client cannot be built, the request
/// fails, or the status is not a success) is logged at `warn` level and reported as
/// `false`; nothing is surfaced as an error.
#[must_use]
pub fn is_available(&self) -> bool {
    let url = format!("{}/api/tags", self.base_url.trim_end_matches('/'));
    let client = match build_client(HEALTH_TIMEOUT_SECS) {
        Ok(client) => client,
        Err(e) => {
            warn!(url = %url, error = %e, "failed to build HTTP client for health check");
            return false;
        }
    };
    match client.get(&url).send() {
        Ok(resp) => {
            let status = resp.status();
            let ok = status.is_success();
            if !ok {
                warn!(url = %url, status = %status, "Ollama health check failed");
            }
            ok
        }
        Err(e) => {
            warn!(url = %url, error = %e, "Ollama health check request failed");
            false
        }
    }
}
/// Sends `prompt` to the model and returns the resulting candidate once it has passed
/// `crate::validator::validate` against `input_source_ids`.
///
/// # Errors
///
/// Propagates the error of the request/parsing step ([`LlmError::Http`] or
/// [`LlmError::SchemaValidation`]), or the error reported by the validator when it
/// rejects the candidate.
pub fn generate(
    &self,
    prompt: &str,
    input_source_ids: &[String],
) -> Result<LlmCandidate, LlmError> {
    let candidate = self.post_generate(prompt)?;
    crate::validator::validate(&candidate, input_source_ids)?;
    Ok(candidate)
}
/// Sends `prompt` to the model and returns the resulting candidate once it has passed
/// `crate::validator::validate_with_sources`, which receives both `input_source_ids`
/// and the `source_content` map.
///
/// The map may use any hasher (`S: BuildHasher`).
///
/// # Errors
///
/// Propagates the error of the request/parsing step ([`LlmError::Http`] or
/// [`LlmError::SchemaValidation`]), or the error reported by the validator when it
/// rejects the candidate.
pub fn generate_with_grounding<S: std::hash::BuildHasher>(
    &self,
    prompt: &str,
    input_source_ids: &[String],
    source_content: &HashMap<String, String, S>,
) -> Result<LlmCandidate, LlmError> {
    let candidate = self.post_generate(prompt)?;
    crate::validator::validate_with_sources(&candidate, input_source_ids, source_content)?;
    Ok(candidate)
}
/// Posts `prompt` to `{base_url}/api/generate` and parses the model output into an
/// [`LlmCandidate`].
///
/// The request asks for a non-streamed, JSON-formatted completion and forwards
/// `temperature` and `num_ctx` as options. Trailing slashes on `base_url` are
/// stripped before the path is appended, so the request never targets
/// `//api/generate`. The candidate is not validated here.
///
/// Fails with [`LlmError::Http`] when the client cannot be built, the request fails,
/// the status is not a success, or the response body cannot be decoded, and with
/// [`LlmError::SchemaValidation`] when the `response` string is not a valid candidate.
fn post_generate(&self, prompt: &str) -> Result<LlmCandidate, LlmError> {
    let url = format!("{}/api/generate", self.base_url.trim_end_matches('/'));
    debug!(
        model = %self.model,
        prompt_len = prompt.len(),
        "sending generate request to Ollama"
    );
    let body = serde_json::json!({
        "model": self.model,
        "prompt": prompt,
        "format": "json",
        "stream": false,
        "options": {
            "temperature": self.temperature,
            "num_ctx": self.num_ctx,
        }
    });
    let client = build_client(GENERATE_TIMEOUT_SECS)?;
    let http_resp = client
        .post(&url)
        .json(&body)
        .send()
        .map_err(|e| LlmError::Http(e.to_string()))?;
    let status = http_resp.status();
    if !status.is_success() {
        // Best effort: an unreadable error body is reported as an empty string.
        let text = http_resp.text().unwrap_or_default();
        warn!(status = %status, body = %text, "Ollama returned non-2xx");
        return Err(LlmError::Http(format!("HTTP {status}: {text}")));
    }
    let raw: OllamaGenerateResponse = http_resp
        .json()
        .map_err(|e| LlmError::Http(e.to_string()))?;
    debug!("received Ollama response, deserialising candidate");
    // The model output arrives as a JSON string nested inside the response body.
    serde_json::from_str(&raw.response).map_err(|e| LlmError::SchemaValidation(e.to_string()))
}
}
/// Creates a blocking `reqwest` client whose overall request timeout is `timeout_secs`
/// seconds and whose connect timeout is `CONNECT_TIMEOUT_SECS` seconds.
///
/// A builder failure is reported as [`LlmError::Http`] carrying the error text.
fn build_client(timeout_secs: u64) -> Result<reqwest::blocking::Client, LlmError> {
    let request_timeout = Duration::from_secs(timeout_secs);
    let connect_timeout = Duration::from_secs(CONNECT_TIMEOUT_SECS);
    let built = reqwest::blocking::Client::builder()
        .timeout(request_timeout)
        .connect_timeout(connect_timeout)
        .build();
    match built {
        Ok(client) => Ok(client),
        Err(e) => Err(LlmError::Http(e.to_string())),
    }
}
/// The part of the `/api/generate` response body that this module reads.
///
/// Only `response` is declared; the test fixture in this file deserialises a body that
/// also carries `model` and `done`, which are not captured here.
#[derive(Debug, Deserialize)]
struct OllamaGenerateResponse {
/// Model output as a string; it is parsed again as JSON to obtain the candidate.
response: String,
}
/// Errors produced while requesting, parsing, or validating LLM output.
#[derive(Clone, Debug, thiserror::Error)]
pub enum LlmError {
/// Building the HTTP client, sending the request, receiving a non-success status, or
/// decoding the response body failed; the payload is the error text.
#[error("HTTP error: {0}")]
Http(String),
/// The model's `response` string could not be deserialised into a candidate; the
/// payload is the deserialiser's error text.
#[error("JSON schema validation failed: {0}")]
SchemaValidation(String),
/// A source ID was not found in the input sources; the payload is that ID.
#[error("source ID '{0}' not found in input sources")]
SourceIdMissing(String),
/// A claim used a claim type that is not allowed in LLM output (the message cites
/// ADR 0002); the payload is the offending type.
#[error("claim type {0:?} is not allowed from LLM output (ADR 0002)")]
ForbiddenClaimType(crate::candidate::ClaimType),
/// A claim was not grounded in the sources it cites.
#[error("claim '{claim_text}' is not grounded in cited sources {source_ids:?}")]
SourceNotGrounded {
/// Text of the claim that failed the grounding check.
claim_text: String,
/// Source IDs reported alongside the claim in the error message.
source_ids: Vec<String>,
},
}
#[cfg(test)]
mod tests {
use super::*;
use crate::candidate::{ClaimConfidence, ClaimType, LlmClaim, SCHEMA};
/// Builds a generate-style response body for tests: the candidate serialised to a JSON
/// string under `response`, next to fixed `model` and `done` fields.
fn wrap_candidate(candidate: &LlmCandidate) -> serde_json::Value {
    serde_json::json!({
        "response": serde_json::to_string(candidate).expect("serialise"),
        "model": "qwen2.5-coder:14b",
        "done": true
    })
}
/// Builds a candidate with a single `CandidateObservation` claim; `source_ids` is used
/// both as the candidate's input sources and as the claim's cited sources.
fn valid_candidate(source_ids: Vec<String>) -> LlmCandidate {
    let claim = LlmClaim {
        text: "the project uses a deterministic-first approach".into(),
        claim_type: ClaimType::CandidateObservation,
        source_ids: source_ids.clone(),
        confidence: ClaimConfidence::Candidate,
    };
    LlmCandidate {
        schema: SCHEMA.into(),
        candidate_id: "test-cand-001".into(),
        input_source_ids: source_ids,
        claims: vec![claim],
    }
}
// Round-trips a candidate through the response body and the nested JSON string.
#[test]
fn deserialise_ollama_response() {
    let envelope = wrap_candidate(&valid_candidate(vec!["src_abc".into()]));
    let outer: OllamaGenerateResponse =
        serde_json::from_value(envelope).expect("parse ollama body");
    let inner: LlmCandidate = serde_json::from_str(&outer.response).expect("parse candidate");
    assert_eq!(inner.schema, SCHEMA);
    assert_eq!(inner.candidate_id, "test-cand-001");
    assert_eq!(inner.claims.len(), 1);
    assert_eq!(inner.claims[0].claim_type, ClaimType::CandidateObservation);
}
// A `HardRule` claim must be refused by the validator.
#[test]
fn hard_rule_claim_rejected_by_validate() {
    let allowed = [String::from("src_abc")];
    let hard_rule = LlmClaim {
        text: "always vendor all dependencies".into(),
        claim_type: ClaimType::HardRule,
        source_ids: allowed.to_vec(),
        confidence: ClaimConfidence::Candidate,
    };
    let candidate = LlmCandidate {
        schema: SCHEMA.into(),
        candidate_id: "bad-cand".into(),
        input_source_ids: allowed.to_vec(),
        claims: vec![hard_rule],
    };
    let result = crate::validator::validate(&candidate, &allowed);
    let rejected = matches!(
        result,
        Err(LlmError::ForbiddenClaimType(ClaimType::HardRule))
    );
    assert!(
        rejected,
        "expected ForbiddenClaimType(HardRule), got {result:?}"
    );
}
// A `ProjectInvariant` claim must be refused by the validator.
#[test]
fn project_invariant_rejected_by_validate() {
    let allowed = [String::from("src_abc")];
    let invariant = LlmClaim {
        text: "CI must never be skipped".into(),
        claim_type: ClaimType::ProjectInvariant,
        source_ids: allowed.to_vec(),
        confidence: ClaimConfidence::Candidate,
    };
    let candidate = LlmCandidate {
        schema: SCHEMA.into(),
        candidate_id: "bad-cand-2".into(),
        input_source_ids: allowed.to_vec(),
        claims: vec![invariant],
    };
    let result = crate::validator::validate(&candidate, &allowed);
    let rejected = matches!(
        result,
        Err(LlmError::ForbiddenClaimType(ClaimType::ProjectInvariant))
    );
    assert!(
        rejected,
        "expected ForbiddenClaimType(ProjectInvariant), got {result:?}"
    );
}
// A claim citing a source that is absent from the input list must be refused.
#[test]
fn missing_source_id_rejected_by_validate() {
    let known = [String::from("real_source")];
    let candidate = valid_candidate(vec!["ghost_source".into()]);
    let result = crate::validator::validate(&candidate, &known);
    let rejected = matches!(result, Err(LlmError::SourceIdMissing(_)));
    assert!(rejected, "expected SourceIdMissing, got {result:?}");
}
// The health check must report `false` when nothing is listening on the target port.
#[test]
fn is_available_returns_false_for_unreachable_host() {
    // Ask the OS for a free loopback port and release it again, rather than assuming
    // a hard-coded port number is unused on the machine running the tests.
    let port = std::net::TcpListener::bind("127.0.0.1:0")
        .and_then(|listener| listener.local_addr())
        .expect("reserve an ephemeral port")
        .port();
    let adapter = OllamaAdapter::new(format!("http://127.0.0.1:{port}"), "qwen2.5-coder:14b");
    let avail = adapter.is_available();
    assert!(!avail);
}
// A request to a closed port must return well before the generate timeout elapses.
#[test]
fn build_client_has_timeouts() {
    // Ask the OS for a free loopback port and release it again, rather than assuming
    // a hard-coded port number is closed on the machine running the tests.
    let port = std::net::TcpListener::bind("127.0.0.1:0")
        .and_then(|listener| listener.local_addr())
        .expect("reserve an ephemeral port")
        .port();
    let url = format!("http://127.0.0.1:{port}/api/tags");
    let client = build_client(GENERATE_TIMEOUT_SECS).expect("build client");
    let start = std::time::Instant::now();
    // Only the elapsed time matters; the request itself is expected to fail.
    let _ = client.get(&url).send();
    let elapsed = start.elapsed();
    assert!(
        elapsed < Duration::from_secs(30),
        "build_client should fail-fast on closed port, took {elapsed:?}"
    );
}
// A claim whose text is unrelated to the cited source content must be refused.
#[test]
fn generate_with_grounding_rejects_ungrounded_via_validator() {
    let allowed = [String::from("src_abc")];
    let claim = LlmClaim {
        text: "the project enforces deterministic builds across all hosts".into(),
        claim_type: ClaimType::CandidateObservation,
        source_ids: allowed.to_vec(),
        confidence: ClaimConfidence::Candidate,
    };
    let candidate = LlmCandidate {
        schema: SCHEMA.into(),
        candidate_id: "ungrounded".into(),
        input_source_ids: allowed.to_vec(),
        claims: vec![claim],
    };
    let sources = HashMap::from([(
        "src_abc".to_string(),
        "this file has totally unrelated prose".to_string(),
    )]);
    let result = crate::validator::validate_with_sources(&candidate, &allowed, &sources);
    let rejected = matches!(result, Err(LlmError::SourceNotGrounded { .. }));
    assert!(rejected, "expected SourceNotGrounded, got {result:?}");
}
// A claim that echoes wording found in the cited source content must be accepted.
#[test]
fn generate_with_grounding_accepts_grounded_via_validator() {
    let allowed = [String::from("src_abc")];
    let claim = LlmClaim {
        text: "deterministic builds across all hosts are required".into(),
        claim_type: ClaimType::CandidateObservation,
        source_ids: allowed.to_vec(),
        confidence: ClaimConfidence::Candidate,
    };
    let candidate = LlmCandidate {
        schema: SCHEMA.into(),
        candidate_id: "grounded".into(),
        input_source_ids: allowed.to_vec(),
        claims: vec![claim],
    };
    let sources = HashMap::from([(
        "src_abc".to_string(),
        "We need deterministic builds across all hosts to ship safely.".to_string(),
    )]);
    let result = crate::validator::validate_with_sources(&candidate, &allowed, &sources);
    assert!(result.is_ok(), "expected Ok, got {result:?}");
}
}