use anyhow::{anyhow, Result};
use super::agent_model_runs::{ModelAnswer, ModelCaller};
/// Candle backend (provides `CandleGenerator`); compiled only with the `gen-candle` feature.
#[cfg(feature = "gen-candle")]
pub mod candle;
/// mistral.rs backend (provides `MistralRsGenerator`); compiled only with the `gen-mistralrs` feature.
#[cfg(feature = "gen-mistralrs")]
pub mod mistralrs;
/// ONNX backend (provides `OnnxGenerator`); compiled only with the `gen-onnx` feature.
#[cfg(feature = "gen-onnx")]
pub mod onnx;
/// Ollama backend (provides `OllamaGenerator`); compiled only with the `gen-ollama` feature.
#[cfg(feature = "gen-ollama")]
pub mod ollama;
/// Name of the environment variable that [`spec_from_env`] reads to obtain a
/// generator spec string (the same `backend[:model]` form [`generator`] parses).
pub const GEN_BACKEND_ENV: &str = "NORNIR_GEN_BACKEND";
/// A single text-generation request handed to a [`Generator`].
#[derive(Debug, Clone, PartialEq)]
pub struct GenRequest {
    /// The user prompt to complete.
    pub prompt: String,
    /// Optional system prompt; `None` when the caller supplied none.
    pub system: Option<String>,
    /// Upper bound on the number of tokens to generate.
    pub max_tokens: usize,
    /// Sampling temperature.
    pub temperature: f32,
    /// Stop strings, in the order they were added.
    pub stop: Vec<String>,
}

impl GenRequest {
    /// Builds a request for `prompt` with the defaults: no system prompt,
    /// `max_tokens = 256`, `temperature = 0.0`, and no stop strings.
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            prompt: prompt.into(),
            system: None,
            max_tokens: 256,
            temperature: 0.0,
            stop: Vec::new(),
        }
    }

    /// Returns the request with its system prompt set to `system`.
    pub fn with_system(self, system: impl Into<String>) -> Self {
        Self {
            system: Some(system.into()),
            ..self
        }
    }

    /// Returns the request with `max_tokens` replaced by `n`.
    pub fn with_max_tokens(self, n: usize) -> Self {
        Self {
            max_tokens: n,
            ..self
        }
    }

    /// Returns the request with `temperature` replaced by `t`.
    pub fn with_temperature(self, t: f32) -> Self {
        Self {
            temperature: t,
            ..self
        }
    }

    /// Returns the request with `s` appended to the list of stop strings.
    pub fn with_stop(mut self, s: impl Into<String>) -> Self {
        let stop_string: String = s.into();
        self.stop.push(stop_string);
        self
    }
}
/// The result of one [`Generator::complete`] call.
#[derive(Debug, Clone, PartialEq)]
pub struct GenAnswer {
    /// The generated text.
    pub text: String,
    /// Number of input (prompt) tokens reported by the backend.
    pub tokens_in: i64,
    /// Number of output tokens reported by the backend.
    pub tokens_out: i64,
    /// Generation throughput in tokens per second.
    pub tokens_per_s: f64,
    /// Latency of the call in milliseconds.
    pub latency_ms: f64,
}

impl GenAnswer {
    /// Consumes the answer and repackages it as a [`ModelAnswer`].
    ///
    /// The text becomes `output` and the token/latency figures are carried over
    /// unchanged; `score`, `cost_usd` and `mcp_tool_calls` are all set to zero.
    pub fn into_model_answer(self) -> ModelAnswer {
        let Self {
            text,
            tokens_in,
            tokens_out,
            tokens_per_s,
            latency_ms,
        } = self;

        ModelAnswer {
            output: text,
            latency_ms,
            tokens_in,
            tokens_out,
            tokens_per_s,
            score: 0.0,
            cost_usd: 0.0,
            mcp_tool_calls: 0,
        }
    }
}
/// Identity and readiness queries; a supertrait of [`Generator`].
pub trait Backend {
/// Returns the backend's identifier string.
fn id(&self) -> &str;
/// Reports whether the backend can be used; the criteria are left to each
/// implementation (the mock backend always returns `true`).
fn available(&self) -> bool;
}
/// A text-generation backend. Implementors must also be [`Backend`], `Send`
/// and `Sync`.
pub trait Generator: Backend + Send + Sync {
/// Runs the request `req` and returns the resulting [`GenAnswer`], or an
/// error if the backend could not produce one.
fn complete(&self, req: &GenRequest) -> Result<GenAnswer>;
}
/// Adapter that exposes a statically-typed [`Generator`] through the
/// [`ModelCaller`] interface.
pub struct GeneratorCaller<G: Generator> {
    // Deliberately not called `gen`: that identifier is a reserved keyword
    // starting with the Rust 2024 edition.
    inner: G,
}

impl<G: Generator> GeneratorCaller<G> {
    /// Wraps `inner` so it can be used wherever a [`ModelCaller`] is expected.
    pub fn new(inner: G) -> Self {
        Self { inner }
    }
}

impl<G: Generator> ModelCaller for GeneratorCaller<G> {
    /// Sends `prompt` to the wrapped generator and converts its answer.
    ///
    /// The agent and model names are ignored. The request is built with
    /// [`GenRequest::new`], so it carries that constructor's defaults.
    ///
    /// # Errors
    /// Propagates any error returned by the generator's `complete`.
    fn call(&self, _agent: &str, _model: &str, prompt: &str) -> Result<ModelAnswer> {
        let request = GenRequest::new(prompt);
        let answer = self.inner.complete(&request)?;
        Ok(answer.into_model_answer())
    }
}
/// Adapter that exposes a boxed, dynamically-dispatched [`Generator`] (such as
/// the one returned by [`generator`]) through the [`ModelCaller`] interface.
pub struct BoxGeneratorCaller {
    // Deliberately not called `gen`: that identifier is a reserved keyword
    // starting with the Rust 2024 edition.
    inner: Box<dyn Generator>,
}

impl BoxGeneratorCaller {
    /// Wraps the boxed generator `inner`.
    pub fn new(inner: Box<dyn Generator>) -> Self {
        Self { inner }
    }

    /// Returns the wrapped generator's identifier.
    pub fn id(&self) -> &str {
        self.inner.id()
    }

    /// Reports whether the wrapped generator says it is available.
    pub fn available(&self) -> bool {
        self.inner.available()
    }
}

impl ModelCaller for BoxGeneratorCaller {
    /// Sends `prompt` to the wrapped generator and converts its answer.
    ///
    /// The agent and model names are ignored. The request is built with
    /// [`GenRequest::new`], so it carries that constructor's defaults.
    ///
    /// # Errors
    /// Propagates any error returned by the generator's `complete`.
    fn call(&self, _agent: &str, _model: &str, prompt: &str) -> Result<ModelAnswer> {
        let request = GenRequest::new(prompt);
        let answer = self.inner.complete(&request)?;
        Ok(answer.into_model_answer())
    }
}
/// Reads the generator spec from the [`GEN_BACKEND_ENV`] environment variable.
///
/// Returns `None` when the variable cannot be read as a string or when its
/// value is empty or whitespace-only; otherwise returns the value exactly as
/// stored (it is not trimmed here).
pub fn spec_from_env() -> Option<String> {
    match std::env::var(GEN_BACKEND_ENV) {
        Ok(raw) if !raw.trim().is_empty() => Some(raw),
        _ => None,
    }
}
/// Builds a boxed [`Generator`] from a `backend[:model]` spec string.
///
/// Surrounding whitespace is trimmed from the spec, the backend name and the
/// model. Everything after the first `:` is the model; without a `:` the model
/// is the empty string. The `mock` backend uses the model as its id and falls
/// back to `"mock"` when the model is empty.
///
/// # Errors
/// * the backend name is not one of `mock`, `candle`, `mistralrs`, `onnx`,
///   `ollama`;
/// * the backend is known but its cargo feature was not compiled in;
/// * the backend's own constructor fails (candle, mistralrs, onnx).
pub fn generator(spec: &str) -> Result<Box<dyn Generator>> {
    let spec = spec.trim();
    let (backend, model) = spec
        .split_once(':')
        .map_or((spec, ""), |(name, rest)| (name.trim(), rest.trim()));

    // Each feature-gated backend has two mutually exclusive arms: exactly one
    // of them survives conditional compilation.
    match backend {
        "mock" => {
            let id = if model.is_empty() { "mock" } else { model };
            Ok(Box::new(MockGenerator::new(id)))
        }

        #[cfg(feature = "gen-candle")]
        "candle" => Ok(Box::new(candle::CandleGenerator::new(model)?)),
        #[cfg(not(feature = "gen-candle"))]
        "candle" => Err(not_compiled("candle", "gen-candle")),

        #[cfg(feature = "gen-mistralrs")]
        "mistralrs" => Ok(Box::new(mistralrs::MistralRsGenerator::new(model)?)),
        #[cfg(not(feature = "gen-mistralrs"))]
        "mistralrs" => Err(not_compiled("mistralrs", "gen-mistralrs")),

        #[cfg(feature = "gen-onnx")]
        "onnx" => Ok(Box::new(onnx::OnnxGenerator::new(model)?)),
        #[cfg(not(feature = "gen-onnx"))]
        "onnx" => Err(not_compiled("onnx", "gen-onnx")),

        #[cfg(feature = "gen-ollama")]
        "ollama" => Ok(Box::new(ollama::OllamaGenerator::new(model, None))),
        #[cfg(not(feature = "gen-ollama"))]
        "ollama" => Err(not_compiled("ollama", "gen-ollama")),

        other => Err(anyhow!(
            "unknown generator backend `{other}` in spec `{spec}` — \
             expected one of candle:<m> | mistralrs:<m> | onnx:<m> | ollama:<m> | mock"
        )),
    }
}
#[allow(dead_code)] fn not_compiled(backend: &str, feature: &str) -> anyhow::Error {
anyhow!(
"generator backend `{backend}` is not compiled in — \
rebuild with `--features {feature}`"
)
}
/// Deterministic in-process generator that echoes its request back.
#[derive(Debug, Clone)]
pub struct MockGenerator {
    /// Identifier reported through [`Backend::id`].
    id: String,
}

impl MockGenerator {
    /// Creates a mock generator that reports `id` as its identifier.
    pub fn new(id: impl Into<String>) -> Self {
        MockGenerator { id: id.into() }
    }

    /// Renders the echo text: `mock[<max_tokens>]: ` followed by the system
    /// prompt (empty when absent) and then the prompt, with no separator
    /// between the two.
    fn answer_text(req: &GenRequest) -> String {
        let system = match &req.system {
            Some(text) => text.as_str(),
            None => "",
        };
        format!("mock[{}]: {}{}", req.max_tokens, system, req.prompt)
    }
}
impl Backend for MockGenerator {
    /// Returns the id given to [`MockGenerator::new`].
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// The mock needs no external resources, so it always reports `true`.
    fn available(&self) -> bool {
        true
    }
}
impl Generator for MockGenerator {
    /// Produces the echo answer for `req`; never fails.
    ///
    /// Token counts are whitespace-separated word counts: `tokens_in` of the
    /// prompt alone, `tokens_out` of the echoed text. Latency is a fixed
    /// 10 ms and `tokens_per_s` is derived from it.
    fn complete(&self, req: &GenRequest) -> Result<GenAnswer> {
        /// Fixed latency the mock reports, in milliseconds.
        const MOCK_LATENCY_MS: f64 = 10.0;

        /// Counts whitespace-separated words.
        fn word_count(s: &str) -> i64 {
            s.split_whitespace().count() as i64
        }

        let text = Self::answer_text(req);
        let tokens_out = word_count(&text);

        Ok(GenAnswer {
            tokens_in: word_count(&req.prompt),
            tokens_out,
            tokens_per_s: tokens_out as f64 / (MOCK_LATENCY_MS / 1000.0),
            latency_ms: MOCK_LATENCY_MS,
            text,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshots an environment variable and puts it back (or removes it) on
    /// drop, so the original value is restored even when an assertion panics.
    struct EnvVarGuard {
        key: &'static str,
        saved: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        /// Records the current value of `key`; `var_os` is used so that a
        /// non-UTF-8 value is preserved as well.
        fn capture(key: &'static str) -> Self {
            Self {
                key,
                saved: std::env::var_os(key),
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.saved.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    /// The mock backend echoes prompt, system prompt and `max_tokens`, and
    /// reports its fixed synthetic metrics.
    #[test]
    fn mock_round_trips_request_into_answer() {
        let mock = generator("mock").unwrap();
        assert_eq!(mock.id(), "mock");
        assert!(mock.available(), "mock is always available");
        let req = GenRequest::new("capital of France?")
            .with_system("be terse")
            .with_max_tokens(64);
        let ans = mock.complete(&req).unwrap();
        assert!(ans.text.contains("capital of France?"), "echoes prompt: {}", ans.text);
        assert!(ans.text.contains("be terse"), "echoes system: {}", ans.text);
        assert!(ans.text.contains("64"), "echoes max_tokens: {}", ans.text);
        assert_eq!(ans.tokens_in, 3, "3 prompt words");
        assert!(ans.tokens_out > 0);
        assert!(ans.tokens_per_s > 0.0);
        assert!((ans.latency_ms - 10.0).abs() < 1e-9);
    }

    /// `mock:<id>` uses the model part of the spec as the generator id.
    #[test]
    fn mock_with_explicit_id_keeps_the_id() {
        let mock = generator("mock:tiny").unwrap();
        assert_eq!(mock.id(), "tiny");
    }

    /// An unrecognised backend name yields an error that names it.
    #[test]
    fn factory_errors_on_unknown_backend() {
        let err = match generator("wat:model") {
            Ok(_) => panic!("unknown backend must error"),
            Err(e) => e.to_string(),
        };
        assert!(err.contains("unknown generator backend"), "{err}");
        assert!(err.contains("wat"), "{err}");
    }

    /// Checks one feature-gated factory arm: with the feature off the error
    /// must be the "not compiled" one; with it on, any error must come from
    /// the backend itself rather than from the factory.
    fn assert_backend_arm(spec: &str, compiled: bool) {
        match generator(spec) {
            Ok(_) => assert!(compiled, "{spec} produced a generator but its feature is off"),
            Err(e) => {
                let s = e.to_string();
                if compiled {
                    assert!(
                        !s.contains("unknown generator backend") && !s.contains("not compiled"),
                        "{spec} is compiled — error must be model-level, got: {s}"
                    );
                } else {
                    assert!(
                        s.contains("not compiled") || s.contains("rebuild with"),
                        "{spec} feature is off — expected 'not compiled', got: {s}"
                    );
                }
            }
        }
    }

    /// Every feature-gated backend the factory knows about is exercised,
    /// including `ollama`.
    #[test]
    fn factory_reports_uncompiled_backends() {
        assert_backend_arm("candle:m", cfg!(feature = "gen-candle"));
        assert_backend_arm("mistralrs:m", cfg!(feature = "gen-mistralrs"));
        assert_backend_arm("onnx:m", cfg!(feature = "gen-onnx"));
        assert_backend_arm("ollama:m", cfg!(feature = "gen-ollama"));
        match generator("nope:m") {
            Ok(_) => panic!("`nope` is not a backend"),
            Err(e) => assert!(
                e.to_string().contains("unknown generator backend"),
                "{}",
                e
            ),
        }
    }

    /// `GeneratorCaller` forwards the prompt and zeroes the cost fields.
    #[test]
    fn generator_caller_bridges_to_modelcaller() {
        let caller = GeneratorCaller::new(MockGenerator::new("mock"));
        let ans = caller.call("local-llm", "mock", "2+2?").unwrap();
        assert!(ans.output.contains("2+2?"));
        assert_eq!(ans.cost_usd, 0.0, "local generation is free");
        assert_eq!(ans.mcp_tool_calls, 0);
        assert!(ans.tokens_out > 0);
    }

    /// `spec_from_env` returns the variable's value and treats blank as unset.
    #[test]
    fn spec_from_env_reads_the_var() {
        // Restores the variable on every exit path, including a failed assert.
        let _restore = EnvVarGuard::capture(GEN_BACKEND_ENV);
        std::env::set_var(GEN_BACKEND_ENV, "mock:envpick");
        assert_eq!(spec_from_env().as_deref(), Some("mock:envpick"));
        std::env::set_var(GEN_BACKEND_ENV, "   ");
        assert_eq!(spec_from_env(), None, "blank is treated as unset");
    }
}