use async_trait::async_trait;
use ranvier_core::bus::Bus;
use ranvier_core::outcome::Outcome;
use ranvier_core::transition::Transition;
use serde::{Deserialize, Serialize};
use std::fmt;
/// Identifies which LLM backend an [`LlmTransition`] should call.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum LlmProvider {
    /// In-process fake provider for tests; behavior is driven by a
    /// [`MockLlmConfig`] read from the Bus (or its `Default`).
    Mock,
    /// Anthropic Claude — not yet implemented in this crate.
    Claude,
    /// OpenAI — not yet implemented in this crate.
    OpenAI,
    /// A user-defined provider identified by name; has no built-in
    /// implementation here.
    Custom(String),
}
impl fmt::Display for LlmProvider {
    /// Renders the lowercase wire name: `mock`, `claude`, `openai`, or
    /// `custom:<name>`. This string is embedded in the default transition
    /// label (`LLM:<provider>`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // The only variant carrying data is formatted; the rest map to
            // fixed tags.
            LlmProvider::Custom(name) => write!(f, "custom:{name}"),
            LlmProvider::Mock => f.write_str("mock"),
            LlmProvider::Claude => f.write_str("claude"),
            LlmProvider::OpenAI => f.write_str("openai"),
        }
    }
}
/// Errors produced by [`LlmTransition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LlmError {
    /// The selected provider cannot be used at all (e.g. its feature is not
    /// compiled in).
    ProviderUnavailable {
        provider: String,
        reason: String,
    },
    /// A `{{...}}` placeholder named a variable that was not found in the
    /// Bus's [`LlmTemplateVars`] (also raised for an unterminated placeholder).
    TemplateMissing {
        variable: String,
    },
    /// Every attempt (1 + retry_count) against the provider failed; carries
    /// the error message from the final attempt.
    RequestFailed {
        provider: String,
        attempts: u32,
        last_error: String,
    },
    /// The response parsed as JSON but did not match the configured output
    /// schema.
    SchemaValidation {
        expected_schema: serde_json::Value,
        raw_response: String,
        reason: String,
    },
    /// A schema was configured but the raw response was not valid JSON.
    ResponseParse {
        raw_response: String,
        reason: String,
    },
}
impl fmt::Display for LlmError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LlmError::ProviderUnavailable { provider, reason } => {
write!(f, "LLM provider `{provider}` unavailable: {reason}")
}
LlmError::TemplateMissing { variable } => {
write!(f, "template variable `{variable}` not found on Bus")
}
LlmError::RequestFailed {
provider,
attempts,
last_error,
} => {
write!(
f,
"LLM request to `{provider}` failed after {attempts} attempt(s): {last_error}"
)
}
LlmError::SchemaValidation {
reason,
raw_response,
..
} => {
write!(
f,
"LLM response schema validation failed: {reason} (response: {raw_response})"
)
}
LlmError::ResponseParse {
raw_response,
reason,
} => {
write!(
f,
"failed to parse LLM response as JSON: {reason} (response: {raw_response})"
)
}
}
}
}
// LlmError stores its cause as a rendered String rather than a source error,
// so the default `Error` trait methods suffice.
impl std::error::Error for LlmError {}
/// Bus-provided configuration for the [`LlmProvider::Mock`] backend.
///
/// When present on the Bus, `LlmTransition::run` uses it instead of the
/// `Default` mock response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MockLlmConfig {
    // The canned response returned on success.
    pub response: String,
    // When true, every mock call fails with `failure_message` (exercising
    // the retry path).
    pub should_fail: bool,
    // Error string returned when `should_fail` is set.
    pub failure_message: String,
}
impl Default for MockLlmConfig {
fn default() -> Self {
Self {
response: r#"{"result":"mock_response"}"#.to_string(),
should_fail: false,
failure_message: "simulated mock failure".to_string(),
}
}
}
/// A Transition that renders a prompt, calls an LLM provider (with retries),
/// and optionally validates the response against a JSON schema.
///
/// Configured via the builder methods (`model`, `prompt_template`, …).
#[derive(Clone)]
pub struct LlmTransition {
    // Which backend to call.
    provider: LlmProvider,
    // Provider-specific model identifier; `None` means provider default.
    model: Option<String>,
    // Prepended to the prompt as a `[system]` section when set.
    system_prompt: Option<String>,
    // Template with `{{var}}` / `{{json:var}}` / `{{input}}` placeholders.
    prompt_template: Option<String>,
    max_tokens: Option<u32>,
    temperature: Option<f32>,
    // Retries after the first failure; total attempts = retry_count + 1.
    retry_count: u32,
    // When set, responses must parse as JSON matching this schema.
    output_schema: Option<serde_json::Value>,
    // Overrides the default `LLM:<provider>` label.
    label_override: Option<String>,
}
impl fmt::Debug for LlmTransition {
    /// Manual `Debug`: prompt text and the full schema are not printed —
    /// only the schema's presence is shown (presumably to keep log output
    /// compact; confirm before adding prompt fields here).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("LlmTransition");
        dbg.field("provider", &self.provider);
        dbg.field("model", &self.model);
        dbg.field("max_tokens", &self.max_tokens);
        dbg.field("temperature", &self.temperature);
        dbg.field("retry_count", &self.retry_count);
        dbg.field("has_output_schema", &self.output_schema.is_some());
        dbg.finish()
    }
}
impl LlmTransition {
    /// Creates a transition for `provider` with nothing else configured:
    /// no model, no prompts, no schema, zero retries.
    pub fn new(provider: LlmProvider) -> Self {
        Self {
            provider,
            model: None,
            system_prompt: None,
            prompt_template: None,
            max_tokens: None,
            temperature: None,
            retry_count: 0,
            output_schema: None,
            label_override: None,
        }
    }
    /// Sets the provider-specific model identifier.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }
    /// Sets the system prompt prepended to every request.
    pub fn system_prompt(mut self, system: impl Into<String>) -> Self {
        self.system_prompt = Some(system.into());
        self
    }
    /// Sets the user-prompt template. `{{var}}`, `{{json:var}}`, and
    /// `{{input}}` placeholders are substituted at run time.
    pub fn prompt_template(mut self, template: impl Into<String>) -> Self {
        self.prompt_template = Some(template.into());
        self
    }
    /// Caps the response length in tokens.
    pub fn max_tokens(mut self, max: u32) -> Self {
        self.max_tokens = Some(max);
        self
    }
    /// Sets the sampling temperature.
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }
    /// Number of retries after the first failed attempt
    /// (total attempts = `count + 1`).
    pub fn retry_count(mut self, count: u32) -> Self {
        self.retry_count = count;
        self
    }
    /// Overrides the default `LLM:<provider>` label.
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        self.label_override = Some(label.into());
        self
    }
    /// Infers an output schema from `T::default()` and validates every
    /// response against it. If `T` cannot be serialized, the schema is
    /// silently left unset.
    pub fn output_schema<T: Serialize + for<'de> Deserialize<'de> + Default>(mut self) -> Self {
        let sample = T::default();
        if let Ok(value) = serde_json::to_value(&sample) {
            self.output_schema = Some(infer_schema_from_value(&value));
        }
        self
    }
    /// Uses `schema` verbatim for response validation.
    pub fn output_schema_raw(mut self, schema: serde_json::Value) -> Self {
        self.output_schema = Some(schema);
        self
    }
    /// Replaces every `<opener>name}}` placeholder in `result` with
    /// `render(value)` for the variable `name` looked up in `vars`.
    ///
    /// Scanning resumes *after* each inserted value, so placeholder-like
    /// text inside a variable's value is emitted literally rather than
    /// re-expanded. (The previous implementation restarted the search from
    /// the beginning of the string after every replacement, which spun
    /// forever on a self-referential value such as `a = "{{a}}"`.)
    ///
    /// # Errors
    /// [`LlmError::TemplateMissing`] when a placeholder is unterminated or
    /// names a variable absent from `vars`.
    fn substitute(
        result: &mut String,
        opener: &str,
        vars: Option<&LlmTemplateVars>,
        render: impl Fn(&serde_json::Value) -> String,
    ) -> Result<(), LlmError> {
        let mut search_from = 0;
        while let Some(rel) = result[search_from..].find(opener) {
            let start = search_from + rel;
            let after = start + opener.len();
            let end = result[after..]
                .find("}}")
                .map(|i| after + i)
                .ok_or_else(|| LlmError::TemplateMissing {
                    // Unterminated placeholder: report the dangling tail.
                    variable: result[after..].to_string(),
                })?;
            let var_name = result[after..end].to_string();
            let value = vars
                .and_then(|v| v.get(&var_name))
                .ok_or_else(|| LlmError::TemplateMissing { variable: var_name })?;
            let replacement = render(value);
            // `end + 2` consumes the closing "}}" as well.
            result.replace_range(start..end + 2, &replacement);
            search_from = start + replacement.len();
        }
        Ok(())
    }
    /// Renders `template` using [`LlmTemplateVars`] read from the Bus.
    ///
    /// `{{json:var}}` placeholders become the variable's compact JSON
    /// encoding; plain `{{var}}` placeholders become the raw string for
    /// string values and the JSON encoding for everything else. The JSON
    /// pass runs first so the plain pass never misreads a `{{json:` opener
    /// as a variable named `json:var`.
    fn render_prompt(&self, template: &str, bus: &Bus) -> Result<String, LlmError> {
        let vars = bus.read::<LlmTemplateVars>();
        let mut result = template.to_string();
        Self::substitute(&mut result, "{{json:", vars, |value| {
            serde_json::to_string(value).unwrap_or_default()
        })?;
        Self::substitute(&mut result, "{{", vars, |value| match value {
            serde_json::Value::String(s) => s.clone(),
            other => other.to_string(),
        })?;
        Ok(result)
    }
    /// Validates `raw` against the configured output schema, if any.
    ///
    /// # Errors
    /// [`LlmError::ResponseParse`] when `raw` is not valid JSON;
    /// [`LlmError::SchemaValidation`] when it is JSON of the wrong shape.
    fn validate_response(&self, raw: &str) -> Result<(), LlmError> {
        let Some(schema) = &self.output_schema else {
            // No schema configured: any response is acceptable.
            return Ok(());
        };
        let parsed: serde_json::Value =
            serde_json::from_str(raw).map_err(|e| LlmError::ResponseParse {
                raw_response: raw.to_string(),
                reason: e.to_string(),
            })?;
        validate_value_against_schema(&parsed, schema).map_err(|reason| {
            LlmError::SchemaValidation {
                expected_schema: schema.clone(),
                raw_response: raw.to_string(),
                reason,
            }
        })
    }
    /// Dispatches a single request to the configured provider. Only `Mock`
    /// has a built-in implementation; the others report why they are
    /// unavailable (the error string becomes `RequestFailed::last_error`).
    async fn call_provider(&self, prompt: &str) -> Result<String, String> {
        match &self.provider {
            LlmProvider::Mock => self.call_mock(prompt),
            LlmProvider::Claude => {
                Err("Claude provider requires feature `llm-claude` (not yet implemented)".into())
            }
            LlmProvider::OpenAI => {
                Err("OpenAI provider requires feature `llm-openai` (not yet implemented)".into())
            }
            LlmProvider::Custom(name) => Err(format!(
                "Custom provider `{name}` has no built-in implementation; \
                 use a custom Transition instead"
            )),
        }
    }
    /// Mock call with no Bus-provided config: always succeeds with the
    /// `MockLlmConfig::default()` response.
    fn call_mock(&self, _prompt: &str) -> Result<String, String> {
        Ok(MockLlmConfig::default().response)
    }
    /// Mock call honoring a [`MockLlmConfig`] read from the Bus; fails with
    /// the configured message when `should_fail` is set.
    fn call_mock_with_config(
        &self,
        _prompt: &str,
        config: &MockLlmConfig,
    ) -> Result<String, String> {
        if config.should_fail {
            Err(config.failure_message.clone())
        } else {
            Ok(config.response.clone())
        }
    }
}
/// Named JSON values provided on the Bus for prompt-template substitution
/// (consumed by `LlmTransition::render_prompt`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmTemplateVars {
    // Variable name -> JSON value.
    inner: serde_json::Map<String, serde_json::Value>,
}
impl LlmTemplateVars {
    /// Creates an empty variable map.
    pub fn new() -> Self {
        Default::default()
    }
    /// Inserts (or overwrites) `key`; returns `&mut Self` for chaining.
    pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) -> &mut Self {
        let key = key.into();
        self.inner.insert(key, value);
        self
    }
    /// Looks up a variable by name.
    pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
        self.inner.get(key)
    }
    /// True when `key` has been set.
    pub fn contains(&self, key: &str) -> bool {
        self.get(key).is_some()
    }
    /// Iterates over all `(name, value)` pairs.
    pub fn iter(&self) -> serde_json::map::Iter<'_> {
        self.inner.iter()
    }
}
#[async_trait]
impl Transition<String, String> for LlmTransition {
    type Error = LlmError;
    // No shared resources are required; all configuration lives on `self`
    // and the Bus.
    type Resources = ();
    /// The label override if set, otherwise `LLM:<provider>`.
    fn label(&self) -> String {
        self.label_override
            .clone()
            .unwrap_or_else(|| format!("LLM:{}", self.provider))
    }
    /// One-line summary of the call configuration. Unset numeric options are
    /// shown as placeholder defaults (max_tokens=0, temp=1).
    fn description(&self) -> Option<String> {
        let model = self.model.as_deref().unwrap_or("default");
        Some(format!(
            "LLM call via {} (model={model}, max_tokens={}, temp={})",
            self.provider,
            self.max_tokens.unwrap_or(0),
            self.temperature.unwrap_or(1.0),
        ))
    }
    /// Builds the prompt, calls the provider up to `retry_count + 1` times,
    /// and validates the response against the output schema (if any).
    ///
    /// Prompt selection:
    /// - empty `input` + template       -> rendered template
    /// - empty `input`, no template     -> Fault(TemplateMissing)
    /// - non-empty `input` + template   -> template with `{{input}}` replaced,
    ///                                     then rendered
    /// - non-empty `input`, no template -> `input` used verbatim
    async fn run(
        &self,
        input: String,
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<String, Self::Error> {
        let prompt = if input.is_empty() {
            match &self.prompt_template {
                Some(tpl) => match self.render_prompt(tpl, bus) {
                    Ok(rendered) => rendered,
                    Err(e) => return Outcome::Fault(e),
                },
                None => {
                    // Nothing to send: neither input nor a template.
                    return Outcome::Fault(LlmError::TemplateMissing {
                        variable: "(no prompt template or input provided)".into(),
                    });
                }
            }
        } else if self.prompt_template.is_some() {
            // `{{input}}` is substituted *before* Bus-variable rendering, so
            // the input text is never treated as a template variable lookup.
            let tpl = self.prompt_template.as_ref().expect("prompt_template guaranteed by is_some() guard");
            let with_input = tpl.replace("{{input}}", &input);
            match self.render_prompt(&with_input, bus) {
                Ok(rendered) => rendered,
                Err(e) => return Outcome::Fault(e),
            }
        } else {
            input
        };
        // System prompt (if any) is prepended as a tagged section.
        let full_prompt = match &self.system_prompt {
            Some(sys) => format!("[system]\n{sys}\n\n[user]\n{prompt}"),
            None => prompt,
        };
        tracing::debug!(
            provider = %self.provider,
            model = ?self.model,
            prompt_len = full_prompt.len(),
            "LlmTransition executing"
        );
        let max_attempts = self.retry_count + 1;
        let mut last_error = String::new();
        for attempt in 1..=max_attempts {
            let result = match &self.provider {
                LlmProvider::Mock => {
                    // A MockLlmConfig on the Bus overrides the default mock
                    // behavior (enables simulated failures in tests).
                    match bus.read::<MockLlmConfig>() {
                        Some(cfg) => self.call_mock_with_config(&full_prompt, cfg),
                        None => self.call_mock(&full_prompt),
                    }
                }
                _ => self.call_provider(&full_prompt).await,
            };
            match result {
                Ok(response) => {
                    // Schema-invalid responses fail immediately — they are
                    // NOT retried, only transport-level errors are.
                    if let Err(e) = self.validate_response(&response) {
                        tracing::warn!(
                            attempt,
                            provider = %self.provider,
                            "LLM response failed schema validation"
                        );
                        return Outcome::Fault(e);
                    }
                    tracing::debug!(
                        attempt,
                        provider = %self.provider,
                        response_len = response.len(),
                        "LlmTransition completed"
                    );
                    return Outcome::Next(response);
                }
                Err(err) => {
                    tracing::warn!(
                        attempt,
                        max_attempts,
                        provider = %self.provider,
                        error = %err,
                        "LLM call failed"
                    );
                    // Remember the most recent failure for the final Fault.
                    last_error = err;
                }
            }
        }
        // All attempts exhausted.
        Outcome::Fault(LlmError::RequestFailed {
            provider: self.provider.to_string(),
            attempts: max_attempts,
            last_error,
        })
    }
}
fn infer_schema_from_value(value: &serde_json::Value) -> serde_json::Value {
match value {
serde_json::Value::Object(map) => {
let mut properties = serde_json::Map::new();
for (key, val) in map {
properties.insert(key.clone(), infer_schema_from_value(val));
}
serde_json::json!({
"type": "object",
"properties": properties
})
}
serde_json::Value::Array(arr) => {
let items = arr
.first()
.map(infer_schema_from_value)
.unwrap_or_else(|| serde_json::json!({}));
serde_json::json!({
"type": "array",
"items": items
})
}
serde_json::Value::String(_) => serde_json::json!({"type": "string"}),
serde_json::Value::Number(_) => serde_json::json!({"type": "number"}),
serde_json::Value::Bool(_) => serde_json::json!({"type": "boolean"}),
serde_json::Value::Null => serde_json::json!({"type": "null"}),
}
}
/// Structurally validates `value` against a minimal JSON-schema subset:
/// only `type`, object `properties`, and array `items` are honored.
///
/// Properties listed in the schema but absent from the object are accepted
/// (fields are treated as optional); extra fields are ignored.
///
/// # Errors
/// A human-readable path-prefixed reason, e.g.
/// ``property `label`: expected type `string`, got `number` ``.
fn validate_value_against_schema(
    value: &serde_json::Value,
    schema: &serde_json::Value,
) -> Result<(), String> {
    // A schema node without a string `type` constrains nothing.
    let Some(expected_type) = schema.get("type").and_then(|t| t.as_str()) else {
        return Ok(());
    };
    let actual_type = json_type_name(value);
    // JSON Schema's "integer" is a refinement of "number". `json_type_name`
    // reports every number as "number", so without this case a schema passed
    // via `output_schema_raw` declaring `"type": "integer"` could never be
    // satisfied by any value.
    let type_ok = if expected_type == "integer" {
        value.as_i64().is_some() || value.as_u64().is_some()
    } else {
        actual_type == expected_type
    };
    if !type_ok {
        return Err(format!(
            "expected type `{expected_type}`, got `{actual_type}`"
        ));
    }
    if expected_type == "object" {
        if let (Some(props), Some(obj)) = (
            schema.get("properties").and_then(|p| p.as_object()),
            value.as_object(),
        ) {
            for (key, prop_schema) in props {
                // Missing keys are tolerated; only present values are
                // checked against their property schema.
                if let Some(val) = obj.get(key) {
                    validate_value_against_schema(val, prop_schema)
                        .map_err(|e| format!("property `{key}`: {e}"))?;
                }
            }
        }
    }
    if expected_type == "array" {
        if let (Some(items_schema), Some(arr)) = (schema.get("items"), value.as_array()) {
            for (i, elem) in arr.iter().enumerate() {
                validate_value_against_schema(elem, items_schema)
                    .map_err(|e| format!("item[{i}]: {e}"))?;
            }
        }
    }
    Ok(())
}
fn json_type_name(value: &serde_json::Value) -> &'static str {
match value {
serde_json::Value::Null => "null",
serde_json::Value::Bool(_) => "boolean",
serde_json::Value::Number(_) => "number",
serde_json::Value::String(_) => "string",
serde_json::Value::Array(_) => "array",
serde_json::Value::Object(_) => "object",
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Each builder method stores its value; `label()` prefers the override.
    #[test]
    fn builder_sets_all_fields() {
        let t = LlmTransition::new(LlmProvider::Claude)
            .model("claude-sonnet-4-5-20250929")
            .system_prompt("You are a moderator.")
            .prompt_template("Classify: {{content}}")
            .max_tokens(200)
            .temperature(0.3)
            .retry_count(2)
            .with_label("ModerationLLM");
        assert_eq!(t.provider, LlmProvider::Claude);
        assert_eq!(t.model.as_deref(), Some("claude-sonnet-4-5-20250929"));
        assert_eq!(t.system_prompt.as_deref(), Some("You are a moderator."));
        assert_eq!(
            t.prompt_template.as_deref(),
            Some("Classify: {{content}}")
        );
        assert_eq!(t.max_tokens, Some(200));
        assert_eq!(t.temperature, Some(0.3));
        assert_eq!(t.retry_count, 2);
        assert_eq!(t.label(), "ModerationLLM");
    }
    // Without an override the label is `LLM:<provider display name>`.
    #[test]
    fn default_label_includes_provider() {
        let t = LlmTransition::new(LlmProvider::OpenAI);
        assert_eq!(t.label(), "LLM:openai");
    }
    // `{{name}}` is replaced with the string value from LlmTemplateVars.
    #[test]
    fn template_rendering_simple() {
        let t = LlmTransition::new(LlmProvider::Mock)
            .prompt_template("Hello, {{name}}!");
        let mut bus = Bus::new();
        let mut vars = LlmTemplateVars::new();
        vars.set("name", serde_json::json!("Alice"));
        bus.provide(vars);
        let rendered = t.render_prompt("Hello, {{name}}!", &bus).unwrap();
        assert_eq!(rendered, "Hello, Alice!");
    }
    // `{{json:data}}` is replaced with compact JSON serialization.
    #[test]
    fn template_rendering_json_var() {
        let t = LlmTransition::new(LlmProvider::Mock);
        let mut bus = Bus::new();
        let mut vars = LlmTemplateVars::new();
        vars.set("data", serde_json::json!({"key": "value"}));
        bus.provide(vars);
        let rendered = t
            .render_prompt("Payload: {{json:data}}", &bus)
            .unwrap();
        assert_eq!(rendered, r#"Payload: {"key":"value"}"#);
    }
    // An unknown variable (here: empty Bus) yields TemplateMissing.
    #[test]
    fn template_missing_variable_returns_error() {
        let t = LlmTransition::new(LlmProvider::Mock);
        let bus = Bus::new();
        let err = t
            .render_prompt("Hello, {{missing}}!", &bus)
            .unwrap_err();
        assert!(matches!(err, LlmError::TemplateMissing { variable } if variable == "missing"));
    }
    // With no MockLlmConfig on the Bus, the default mock response is used.
    #[tokio::test]
    async fn mock_provider_returns_default_response() {
        let t = LlmTransition::new(LlmProvider::Mock)
            .prompt_template("test prompt");
        let mut bus = Bus::new();
        let mut vars = LlmTemplateVars::new();
        vars.set("_placeholder", serde_json::json!(true));
        bus.provide(vars);
        let outcome = t.run(String::new(), &(), &mut bus).await;
        match outcome {
            Outcome::Next(response) => {
                assert!(response.contains("mock_response"));
            }
            other => panic!("expected Outcome::Next, got {other:?}"),
        }
    }
    // A MockLlmConfig on the Bus overrides the canned response.
    #[tokio::test]
    async fn mock_provider_with_custom_config() {
        let t = LlmTransition::new(LlmProvider::Mock);
        let mut bus = Bus::new();
        bus.provide(MockLlmConfig {
            response: r#"{"label":"safe"}"#.to_string(),
            ..Default::default()
        });
        let outcome = t.run("direct prompt".to_string(), &(), &mut bus).await;
        match outcome {
            Outcome::Next(response) => {
                assert_eq!(response, r#"{"label":"safe"}"#);
            }
            other => panic!("expected Outcome::Next, got {other:?}"),
        }
    }
    // retry_count(1) => 2 total attempts; the final error is surfaced.
    #[tokio::test]
    async fn mock_provider_failure_returns_fault() {
        let t = LlmTransition::new(LlmProvider::Mock).retry_count(1);
        let mut bus = Bus::new();
        bus.provide(MockLlmConfig {
            response: String::new(),
            should_fail: true,
            failure_message: "service unavailable".to_string(),
        });
        let outcome = t.run("test".to_string(), &(), &mut bus).await;
        match outcome {
            Outcome::Fault(LlmError::RequestFailed {
                attempts,
                last_error,
                ..
            }) => {
                assert_eq!(attempts, 2); assert_eq!(last_error, "service unavailable");
            }
            other => panic!("expected Outcome::Fault(RequestFailed), got {other:?}"),
        }
    }
    // A schema expecting an object rejects a bare-string response.
    #[tokio::test]
    async fn schema_validation_rejects_wrong_type() {
        let t = LlmTransition::new(LlmProvider::Mock)
            .output_schema_raw(serde_json::json!({
                "type": "object",
                "properties": {
                    "label": {"type": "string"}
                }
            }));
        let mut bus = Bus::new();
        bus.provide(MockLlmConfig {
            response: r#""just a string""#.to_string(),
            ..Default::default()
        });
        let outcome = t.run("test".to_string(), &(), &mut bus).await;
        assert!(matches!(outcome, Outcome::Fault(LlmError::SchemaValidation { .. })));
    }
    // A well-shaped response passes schema validation and flows through.
    #[tokio::test]
    async fn schema_validation_accepts_valid_response() {
        let t = LlmTransition::new(LlmProvider::Mock)
            .output_schema_raw(serde_json::json!({
                "type": "object",
                "properties": {
                    "label": {"type": "string"},
                    "confidence": {"type": "number"}
                }
            }));
        let mut bus = Bus::new();
        bus.provide(MockLlmConfig {
            response: r#"{"label":"safe","confidence":0.95}"#.to_string(),
            ..Default::default()
        });
        let outcome = t.run("test".to_string(), &(), &mut bus).await;
        assert!(matches!(outcome, Outcome::Next(_)));
    }
    // infer_schema_from_value maps sample fields to schema type tags.
    #[test]
    fn infer_schema_from_sample_object() {
        let sample = serde_json::json!({"name": "test", "count": 0});
        let schema = infer_schema_from_value(&sample);
        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["count"]["type"], "number");
    }
    // Display renders lowercase names; Custom gets a `custom:` prefix.
    #[test]
    fn provider_display() {
        assert_eq!(LlmProvider::Mock.to_string(), "mock");
        assert_eq!(LlmProvider::Claude.to_string(), "claude");
        assert_eq!(LlmProvider::OpenAI.to_string(), "openai");
        assert_eq!(
            LlmProvider::Custom("ollama".into()).to_string(),
            "custom:ollama"
        );
    }
    // Each error variant's Display includes its key fields.
    #[test]
    fn llm_error_display_coverage() {
        let err = LlmError::ProviderUnavailable {
            provider: "claude".into(),
            reason: "feature not enabled".into(),
        };
        assert!(err.to_string().contains("claude"));
        let err = LlmError::TemplateMissing {
            variable: "foo".into(),
        };
        assert!(err.to_string().contains("foo"));
        let err = LlmError::RequestFailed {
            provider: "openai".into(),
            attempts: 3,
            last_error: "timeout".into(),
        };
        assert!(err.to_string().contains("3 attempt(s)"));
        let err = LlmError::ResponseParse {
            raw_response: "not json".into(),
            reason: "unexpected token".into(),
        };
        assert!(err.to_string().contains("unexpected token"));
    }
    // set / get / contains / iter round-trip on LlmTemplateVars.
    #[test]
    fn template_vars_api() {
        let mut vars = LlmTemplateVars::new();
        vars.set("key1", serde_json::json!("value1"));
        vars.set("key2", serde_json::json!(42));
        assert!(vars.contains("key1"));
        assert!(!vars.contains("key3"));
        assert_eq!(vars.get("key1").unwrap(), &serde_json::json!("value1"));
        assert_eq!(vars.iter().count(), 2);
    }
    // Unimplemented providers fail with RequestFailed (after retries).
    #[tokio::test]
    async fn claude_provider_returns_fault_without_feature() {
        let t = LlmTransition::new(LlmProvider::Claude);
        let mut bus = Bus::new();
        let outcome = t.run("test".to_string(), &(), &mut bus).await;
        assert!(matches!(
            outcome,
            Outcome::Fault(LlmError::RequestFailed { .. })
        ));
    }
    // description() embeds provider, model, and numeric parameters.
    #[test]
    fn description_includes_model_and_params() {
        let t = LlmTransition::new(LlmProvider::Claude)
            .model("claude-sonnet-4-5-20250929")
            .max_tokens(200)
            .temperature(0.3);
        let desc = t.description().unwrap();
        assert!(desc.contains("claude"));
        assert!(desc.contains("claude-sonnet-4-5-20250929"));
        assert!(desc.contains("200"));
    }
}