1use async_trait::async_trait;
29use ranvier_core::bus::Bus;
30use ranvier_core::outcome::Outcome;
31use ranvier_core::transition::Transition;
32use serde::{Deserialize, Serialize};
33use std::fmt;
34
/// Identifies which LLM backend an `LlmTransition` talks to.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum LlmProvider {
    /// In-process fake provider for tests; behavior is configurable via a
    /// `MockLlmConfig` placed on the Bus.
    Mock,
    /// Anthropic Claude. No built-in implementation here; calls currently fail
    /// with an error mentioning feature `llm-claude`.
    Claude,
    /// OpenAI. No built-in implementation here; calls currently fail with an
    /// error mentioning feature `llm-openai`.
    OpenAI,
    /// A named external provider; has no built-in implementation — the error
    /// message suggests writing a custom Transition instead.
    Custom(String),
}
54
55impl fmt::Display for LlmProvider {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 match self {
58 LlmProvider::Mock => write!(f, "mock"),
59 LlmProvider::Claude => write!(f, "claude"),
60 LlmProvider::OpenAI => write!(f, "openai"),
61 LlmProvider::Custom(name) => write!(f, "custom:{name}"),
62 }
63 }
64}
65
/// Errors produced by `LlmTransition`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LlmError {
    /// The selected provider cannot be used (e.g. no built-in implementation).
    ProviderUnavailable {
        /// Display form of the provider.
        provider: String,
        /// Why the provider is unavailable.
        reason: String,
    },
    /// A `{{var}}` / `{{json:var}}` placeholder named a variable absent from
    /// the `LlmTemplateVars` on the Bus (also used when no prompt is available
    /// at all, or a placeholder is unterminated).
    TemplateMissing {
        /// The missing variable name (or surrounding text for malformed placeholders).
        variable: String,
    },
    /// Every attempt (initial call plus retries) failed.
    RequestFailed {
        /// Display form of the provider.
        provider: String,
        /// Total attempts made (`retry_count + 1`).
        attempts: u32,
        /// Error string from the final failed attempt.
        last_error: String,
    },
    /// The response parsed as JSON but violated the configured output schema.
    SchemaValidation {
        /// The schema the response was checked against.
        expected_schema: serde_json::Value,
        /// The raw response text that failed validation.
        raw_response: String,
        /// Human-readable mismatch description.
        reason: String,
    },
    /// The response was not valid JSON (only checked when a schema is configured).
    ResponseParse {
        /// The raw response text that failed to parse.
        raw_response: String,
        /// The JSON parser's error message.
        reason: String,
    },
}
102
103impl fmt::Display for LlmError {
104 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105 match self {
106 LlmError::ProviderUnavailable { provider, reason } => {
107 write!(f, "LLM provider `{provider}` unavailable: {reason}")
108 }
109 LlmError::TemplateMissing { variable } => {
110 write!(f, "template variable `{variable}` not found on Bus")
111 }
112 LlmError::RequestFailed {
113 provider,
114 attempts,
115 last_error,
116 } => {
117 write!(
118 f,
119 "LLM request to `{provider}` failed after {attempts} attempt(s): {last_error}"
120 )
121 }
122 LlmError::SchemaValidation {
123 reason,
124 raw_response,
125 ..
126 } => {
127 write!(
128 f,
129 "LLM response schema validation failed: {reason} (response: {raw_response})"
130 )
131 }
132 LlmError::ResponseParse {
133 raw_response,
134 reason,
135 } => {
136 write!(
137 f,
138 "failed to parse LLM response as JSON: {reason} (response: {raw_response})"
139 )
140 }
141 }
142 }
143}
144
// Marker impl: `Display` carries the full message; no `source()` chaining.
impl std::error::Error for LlmError {}
146
/// Bus-provided configuration for the `Mock` provider.
///
/// When present on the Bus, `LlmTransition` with `LlmProvider::Mock` returns
/// `response` verbatim, or fails with `failure_message` when `should_fail` is set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MockLlmConfig {
    // Raw string handed back as the simulated LLM response.
    pub response: String,
    // When true, every mock call fails with `failure_message`.
    pub should_fail: bool,
    // Error text returned when `should_fail` is true.
    pub failure_message: String,
}
171
172impl Default for MockLlmConfig {
173 fn default() -> Self {
174 Self {
175 response: r#"{"result":"mock_response"}"#.to_string(),
176 should_fail: false,
177 failure_message: "simulated mock failure".to_string(),
178 }
179 }
180}
181
/// A Transition that prompts an LLM provider and yields its text response.
///
/// Built fluently via `LlmTransition::new(...)` plus the builder methods.
#[derive(Clone)]
pub struct LlmTransition {
    // Which backend to call.
    provider: LlmProvider,
    // Model identifier forwarded to the provider; None = provider default.
    model: Option<String>,
    // Optional system prompt, prepended as a `[system]` section at run time.
    system_prompt: Option<String>,
    // Optional prompt template with `{{var}}` / `{{json:var}}` / `{{input}}` placeholders.
    prompt_template: Option<String>,
    // Response-length cap in tokens, if set.
    max_tokens: Option<u32>,
    // Sampling temperature, if set.
    temperature: Option<f32>,
    // Number of retries after a failed call (total attempts = retry_count + 1).
    retry_count: u32,
    // Optional JSON-schema-like value responses are validated against.
    output_schema: Option<serde_json::Value>,
    // Optional override for the Transition label.
    label_override: Option<String>,
}
228
229impl fmt::Debug for LlmTransition {
230 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
231 f.debug_struct("LlmTransition")
232 .field("provider", &self.provider)
233 .field("model", &self.model)
234 .field("max_tokens", &self.max_tokens)
235 .field("temperature", &self.temperature)
236 .field("retry_count", &self.retry_count)
237 .field("has_output_schema", &self.output_schema.is_some())
238 .finish()
239 }
240}
241
242impl LlmTransition {
243 pub fn new(provider: LlmProvider) -> Self {
247 Self {
248 provider,
249 model: None,
250 system_prompt: None,
251 prompt_template: None,
252 max_tokens: None,
253 temperature: None,
254 retry_count: 0,
255 output_schema: None,
256 label_override: None,
257 }
258 }
259
260 pub fn model(mut self, model: impl Into<String>) -> Self {
264 self.model = Some(model.into());
265 self
266 }
267
268 pub fn system_prompt(mut self, system: impl Into<String>) -> Self {
270 self.system_prompt = Some(system.into());
271 self
272 }
273
274 pub fn prompt_template(mut self, template: impl Into<String>) -> Self {
279 self.prompt_template = Some(template.into());
280 self
281 }
282
283 pub fn max_tokens(mut self, max: u32) -> Self {
285 self.max_tokens = Some(max);
286 self
287 }
288
289 pub fn temperature(mut self, temp: f32) -> Self {
291 self.temperature = Some(temp);
292 self
293 }
294
295 pub fn retry_count(mut self, count: u32) -> Self {
297 self.retry_count = count;
298 self
299 }
300
301 pub fn with_label(mut self, label: impl Into<String>) -> Self {
303 self.label_override = Some(label.into());
304 self
305 }
306
307 pub fn output_schema<T: Serialize + for<'de> Deserialize<'de> + Default>(mut self) -> Self {
317 let sample = T::default();
321 if let Ok(value) = serde_json::to_value(&sample) {
322 self.output_schema = Some(infer_schema_from_value(&value));
323 }
324 self
325 }
326
327 pub fn output_schema_raw(mut self, schema: serde_json::Value) -> Self {
329 self.output_schema = Some(schema);
330 self
331 }
332
333 fn render_prompt(&self, template: &str, bus: &Bus) -> Result<String, LlmError> {
339 let vars = bus.read::<LlmTemplateVars>();
340 let mut result = template.to_string();
341
342 let json_re = "{{json:";
344 while let Some(start) = result.find(json_re) {
345 let after = start + json_re.len();
346 let end = result[after..]
347 .find("}}")
348 .map(|i| after + i)
349 .ok_or_else(|| LlmError::TemplateMissing {
350 variable: result[after..].to_string(),
351 })?;
352 let var_name = &result[after..end];
353 let value = vars
354 .and_then(|v| v.get(var_name))
355 .ok_or_else(|| LlmError::TemplateMissing {
356 variable: var_name.to_string(),
357 })?;
358 let json_str = serde_json::to_string(value).unwrap_or_default();
359 result.replace_range(start..end + 2, &json_str);
360 }
361
362 let simple_re = "{{";
364 while let Some(start) = result.find(simple_re) {
365 let after = start + simple_re.len();
366 let end = result[after..]
367 .find("}}")
368 .map(|i| after + i)
369 .ok_or_else(|| LlmError::TemplateMissing {
370 variable: result[after..].to_string(),
371 })?;
372 let var_name = &result[after..end];
373 let value = vars
374 .and_then(|v| v.get(var_name))
375 .ok_or_else(|| LlmError::TemplateMissing {
376 variable: var_name.to_string(),
377 })?;
378 let plain_str = match value {
379 serde_json::Value::String(s) => s.clone(),
380 other => other.to_string(),
381 };
382 result.replace_range(start..end + 2, &plain_str);
383 }
384
385 Ok(result)
386 }
387
388 fn validate_response(&self, raw: &str) -> Result<(), LlmError> {
392 let Some(schema) = &self.output_schema else {
393 return Ok(());
394 };
395
396 let parsed: serde_json::Value =
397 serde_json::from_str(raw).map_err(|e| LlmError::ResponseParse {
398 raw_response: raw.to_string(),
399 reason: e.to_string(),
400 })?;
401
402 validate_value_against_schema(&parsed, schema).map_err(|reason| {
403 LlmError::SchemaValidation {
404 expected_schema: schema.clone(),
405 raw_response: raw.to_string(),
406 reason,
407 }
408 })
409 }
410
411 async fn call_provider(&self, prompt: &str) -> Result<String, String> {
415 match &self.provider {
416 LlmProvider::Mock => self.call_mock(prompt),
417 LlmProvider::Claude => {
418 Err("Claude provider requires feature `llm-claude` (not yet implemented)".into())
419 }
420 LlmProvider::OpenAI => {
421 Err("OpenAI provider requires feature `llm-openai` (not yet implemented)".into())
422 }
423 LlmProvider::Custom(name) => Err(format!(
424 "Custom provider `{name}` has no built-in implementation; \
425 use a custom Transition instead"
426 )),
427 }
428 }
429
430 fn call_mock(&self, _prompt: &str) -> Result<String, String> {
432 Ok(MockLlmConfig::default().response)
436 }
437
438 fn call_mock_with_config(&self, _prompt: &str, config: &MockLlmConfig) -> Result<String, String> {
440 if config.should_fail {
441 Err(config.failure_message.clone())
442 } else {
443 Ok(config.response.clone())
444 }
445 }
446}
447
/// Bus-provided map of template variables consumed by `LlmTransition`
/// prompt rendering (`{{var}}` and `{{json:var}}` placeholders).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct LlmTemplateVars {
    // Variable name -> JSON value, in insertion-dependent map order.
    inner: serde_json::Map<String, serde_json::Value>,
}
464
465impl LlmTemplateVars {
466 pub fn new() -> Self {
468 Self::default()
469 }
470
471 pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) -> &mut Self {
473 self.inner.insert(key.into(), value);
474 self
475 }
476
477 pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
479 self.inner.get(key)
480 }
481
482 pub fn contains(&self, key: &str) -> bool {
484 self.inner.contains_key(key)
485 }
486
487 pub fn iter(&self) -> serde_json::map::Iter<'_> {
489 self.inner.iter()
490 }
491}
492
#[async_trait]
impl Transition<String, String> for LlmTransition {
    type Error = LlmError;
    type Resources = ();

    /// The override label if set, else `LLM:<provider>`.
    fn label(&self) -> String {
        self.label_override
            .clone()
            .unwrap_or_else(|| format!("LLM:{}", self.provider))
    }

    /// Human-readable summary of the call configuration. Unset options render
    /// as display placeholders: model "default", max_tokens 0, temp 1.
    fn description(&self) -> Option<String> {
        let model = self.model.as_deref().unwrap_or("default");
        Some(format!(
            "LLM call via {} (model={model}, max_tokens={}, temp={})",
            self.provider,
            self.max_tokens.unwrap_or(0),
            self.temperature.unwrap_or(1.0),
        ))
    }

    /// Executes the LLM call with up to `retry_count + 1` attempts.
    ///
    /// Prompt selection:
    /// - empty `input`: the template is required and rendered from Bus vars;
    ///   missing template is a `TemplateMissing` fault;
    /// - non-empty `input` + template: `{{input}}` is substituted first, then
    ///   the template is rendered;
    /// - non-empty `input` alone: used verbatim as the prompt.
    async fn run(
        &self,
        input: String,
        _resources: &Self::Resources,
        bus: &mut Bus,
    ) -> Outcome<String, Self::Error> {
        let prompt = if input.is_empty() {
            match &self.prompt_template {
                Some(tpl) => match self.render_prompt(tpl, bus) {
                    Ok(rendered) => rendered,
                    Err(e) => return Outcome::Fault(e),
                },
                None => {
                    // Nothing to send: no input and no template configured.
                    return Outcome::Fault(LlmError::TemplateMissing {
                        variable: "(no prompt template or input provided)".into(),
                    });
                }
            }
        } else if self.prompt_template.is_some() {
            let tpl = self
                .prompt_template
                .as_ref()
                .expect("prompt_template guaranteed by is_some() guard");
            // Substitute {{input}} before rendering so the input text itself
            // passes through template rendering along with Bus variables.
            let with_input = tpl.replace("{{input}}", &input);
            match self.render_prompt(&with_input, bus) {
                Ok(rendered) => rendered,
                Err(e) => return Outcome::Fault(e),
            }
        } else {
            input
        };

        // Prepend the system prompt, if any, in a simple tagged-text framing.
        let full_prompt = match &self.system_prompt {
            Some(sys) => format!("[system]\n{sys}\n\n[user]\n{prompt}"),
            None => prompt,
        };

        tracing::debug!(
            provider = %self.provider,
            model = ?self.model,
            prompt_len = full_prompt.len(),
            "LlmTransition executing"
        );

        let max_attempts = self.retry_count + 1;
        let mut last_error = String::new();

        for attempt in 1..=max_attempts {
            let result = match &self.provider {
                LlmProvider::Mock => {
                    // Mock honors a Bus-provided config when present; otherwise
                    // falls back to the default (always-succeeding) mock.
                    match bus.read::<MockLlmConfig>() {
                        Some(cfg) => self.call_mock_with_config(&full_prompt, cfg),
                        None => self.call_mock(&full_prompt),
                    }
                }
                _ => self.call_provider(&full_prompt).await,
            };

            match result {
                Ok(response) => {
                    if let Err(e) = self.validate_response(&response) {
                        tracing::warn!(
                            attempt,
                            provider = %self.provider,
                            "LLM response failed schema validation"
                        );
                        // NOTE(review): a schema-validation failure faults
                        // immediately and does NOT consume remaining retries —
                        // confirm this is intended (re-prompting is a common
                        // recovery for malformed structured output).
                        return Outcome::Fault(e);
                    }

                    tracing::debug!(
                        attempt,
                        provider = %self.provider,
                        response_len = response.len(),
                        "LlmTransition completed"
                    );
                    return Outcome::Next(response);
                }
                Err(err) => {
                    tracing::warn!(
                        attempt,
                        max_attempts,
                        provider = %self.provider,
                        error = %err,
                        "LLM call failed"
                    );
                    // Remember the most recent failure for the final fault.
                    last_error = err;
                }
            }
        }

        // All attempts failed.
        Outcome::Fault(LlmError::RequestFailed {
            provider: self.provider.to_string(),
            attempts: max_attempts,
            last_error,
        })
    }
}
622
623fn infer_schema_from_value(value: &serde_json::Value) -> serde_json::Value {
633 match value {
634 serde_json::Value::Object(map) => {
635 let mut properties = serde_json::Map::new();
636 for (key, val) in map {
637 properties.insert(key.clone(), infer_schema_from_value(val));
638 }
639 serde_json::json!({
640 "type": "object",
641 "properties": properties
642 })
643 }
644 serde_json::Value::Array(arr) => {
645 let items = arr
646 .first()
647 .map(infer_schema_from_value)
648 .unwrap_or_else(|| serde_json::json!({}));
649 serde_json::json!({
650 "type": "array",
651 "items": items
652 })
653 }
654 serde_json::Value::String(_) => serde_json::json!({"type": "string"}),
655 serde_json::Value::Number(_) => serde_json::json!({"type": "number"}),
656 serde_json::Value::Bool(_) => serde_json::json!({"type": "boolean"}),
657 serde_json::Value::Null => serde_json::json!({"type": "null"}),
658 }
659}
660
661fn validate_value_against_schema(
666 value: &serde_json::Value,
667 schema: &serde_json::Value,
668) -> Result<(), String> {
669 let Some(expected_type) = schema.get("type").and_then(|t| t.as_str()) else {
670 return Ok(());
672 };
673
674 let actual_type = json_type_name(value);
675 if actual_type != expected_type {
676 return Err(format!(
677 "expected type `{expected_type}`, got `{actual_type}`"
678 ));
679 }
680
681 if expected_type == "object" {
683 if let (Some(props), Some(obj)) = (
684 schema.get("properties").and_then(|p| p.as_object()),
685 value.as_object(),
686 ) {
687 for (key, prop_schema) in props {
688 match obj.get(key) {
689 Some(val) => validate_value_against_schema(val, prop_schema)
690 .map_err(|e| format!("property `{key}`: {e}"))?,
691 None => {
692 }
695 }
696 }
697 }
698 }
699
700 if expected_type == "array" {
702 if let (Some(items_schema), Some(arr)) = (schema.get("items"), value.as_array()) {
703 for (i, elem) in arr.iter().enumerate() {
704 validate_value_against_schema(elem, items_schema)
705 .map_err(|e| format!("item[{i}]: {e}"))?;
706 }
707 }
708 }
709
710 Ok(())
711}
712
713fn json_type_name(value: &serde_json::Value) -> &'static str {
714 match value {
715 serde_json::Value::Null => "null",
716 serde_json::Value::Bool(_) => "boolean",
717 serde_json::Value::Number(_) => "number",
718 serde_json::Value::String(_) => "string",
719 serde_json::Value::Array(_) => "array",
720 serde_json::Value::Object(_) => "object",
721 }
722}
723
#[cfg(test)]
mod tests {
    use super::*;

    // --- builder and labeling ---

    #[test]
    fn builder_sets_all_fields() {
        let t = LlmTransition::new(LlmProvider::Claude)
            .model("claude-sonnet-4-5-20250929")
            .system_prompt("You are a moderator.")
            .prompt_template("Classify: {{content}}")
            .max_tokens(200)
            .temperature(0.3)
            .retry_count(2)
            .with_label("ModerationLLM");

        assert_eq!(t.provider, LlmProvider::Claude);
        assert_eq!(t.model.as_deref(), Some("claude-sonnet-4-5-20250929"));
        assert_eq!(t.system_prompt.as_deref(), Some("You are a moderator."));
        assert_eq!(
            t.prompt_template.as_deref(),
            Some("Classify: {{content}}")
        );
        assert_eq!(t.max_tokens, Some(200));
        assert_eq!(t.temperature, Some(0.3));
        assert_eq!(t.retry_count, 2);
        assert_eq!(t.label(), "ModerationLLM");
    }

    #[test]
    fn default_label_includes_provider() {
        let t = LlmTransition::new(LlmProvider::OpenAI);
        assert_eq!(t.label(), "LLM:openai");
    }

    // --- template rendering ---

    #[test]
    fn template_rendering_simple() {
        let t = LlmTransition::new(LlmProvider::Mock)
            .prompt_template("Hello, {{name}}!");
        let mut bus = Bus::new();
        let mut vars = LlmTemplateVars::new();
        vars.set("name", serde_json::json!("Alice"));
        bus.provide(vars);

        let rendered = t.render_prompt("Hello, {{name}}!", &bus).unwrap();
        assert_eq!(rendered, "Hello, Alice!");
    }

    #[test]
    fn template_rendering_json_var() {
        let t = LlmTransition::new(LlmProvider::Mock);
        let mut bus = Bus::new();
        let mut vars = LlmTemplateVars::new();
        vars.set("data", serde_json::json!({"key": "value"}));
        bus.provide(vars);

        let rendered = t
            .render_prompt("Payload: {{json:data}}", &bus)
            .unwrap();
        assert_eq!(rendered, r#"Payload: {"key":"value"}"#);
    }

    #[test]
    fn template_missing_variable_returns_error() {
        let t = LlmTransition::new(LlmProvider::Mock);
        let bus = Bus::new();

        let err = t
            .render_prompt("Hello, {{missing}}!", &bus)
            .unwrap_err();
        assert!(matches!(err, LlmError::TemplateMissing { variable } if variable == "missing"));
    }

    // --- mock provider behavior ---

    #[tokio::test]
    async fn mock_provider_returns_default_response() {
        let t = LlmTransition::new(LlmProvider::Mock)
            .prompt_template("test prompt");
        let mut bus = Bus::new();
        let mut vars = LlmTemplateVars::new();
        vars.set("_placeholder", serde_json::json!(true));
        bus.provide(vars);

        // Empty input + template, no MockLlmConfig: default mock response.
        let outcome = t.run(String::new(), &(), &mut bus).await;
        match outcome {
            Outcome::Next(response) => {
                assert!(response.contains("mock_response"));
            }
            other => panic!("expected Outcome::Next, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn mock_provider_with_custom_config() {
        let t = LlmTransition::new(LlmProvider::Mock);
        let mut bus = Bus::new();
        bus.provide(MockLlmConfig {
            response: r#"{"label":"safe"}"#.to_string(),
            ..Default::default()
        });

        let outcome = t.run("direct prompt".to_string(), &(), &mut bus).await;
        match outcome {
            Outcome::Next(response) => {
                assert_eq!(response, r#"{"label":"safe"}"#);
            }
            other => panic!("expected Outcome::Next, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn mock_provider_failure_returns_fault() {
        // retry_count(1) => 2 total attempts before RequestFailed.
        let t = LlmTransition::new(LlmProvider::Mock).retry_count(1);
        let mut bus = Bus::new();
        bus.provide(MockLlmConfig {
            response: String::new(),
            should_fail: true,
            failure_message: "service unavailable".to_string(),
        });

        let outcome = t.run("test".to_string(), &(), &mut bus).await;
        match outcome {
            Outcome::Fault(LlmError::RequestFailed {
                attempts,
                last_error,
                ..
            }) => {
                assert_eq!(attempts, 2);
                assert_eq!(last_error, "service unavailable");
            }
            other => panic!("expected Outcome::Fault(RequestFailed), got {other:?}"),
        }
    }

    // --- schema validation ---

    #[tokio::test]
    async fn schema_validation_rejects_wrong_type() {
        let t = LlmTransition::new(LlmProvider::Mock)
            .output_schema_raw(serde_json::json!({
                "type": "object",
                "properties": {
                    "label": {"type": "string"}
                }
            }));
        let mut bus = Bus::new();
        bus.provide(MockLlmConfig {
            response: r#""just a string""#.to_string(),
            ..Default::default()
        });

        let outcome = t.run("test".to_string(), &(), &mut bus).await;
        assert!(matches!(outcome, Outcome::Fault(LlmError::SchemaValidation { .. })));
    }

    #[tokio::test]
    async fn schema_validation_accepts_valid_response() {
        let t = LlmTransition::new(LlmProvider::Mock)
            .output_schema_raw(serde_json::json!({
                "type": "object",
                "properties": {
                    "label": {"type": "string"},
                    "confidence": {"type": "number"}
                }
            }));
        let mut bus = Bus::new();
        bus.provide(MockLlmConfig {
            response: r#"{"label":"safe","confidence":0.95}"#.to_string(),
            ..Default::default()
        });

        let outcome = t.run("test".to_string(), &(), &mut bus).await;
        assert!(matches!(outcome, Outcome::Next(_)));
    }

    #[test]
    fn infer_schema_from_sample_object() {
        let sample = serde_json::json!({"name": "test", "count": 0});
        let schema = infer_schema_from_value(&sample);
        assert_eq!(schema["type"], "object");
        assert_eq!(schema["properties"]["name"]["type"], "string");
        assert_eq!(schema["properties"]["count"]["type"], "number");
    }

    // --- display / error formatting ---

    #[test]
    fn provider_display() {
        assert_eq!(LlmProvider::Mock.to_string(), "mock");
        assert_eq!(LlmProvider::Claude.to_string(), "claude");
        assert_eq!(LlmProvider::OpenAI.to_string(), "openai");
        assert_eq!(
            LlmProvider::Custom("ollama".into()).to_string(),
            "custom:ollama"
        );
    }

    #[test]
    fn llm_error_display_coverage() {
        let err = LlmError::ProviderUnavailable {
            provider: "claude".into(),
            reason: "feature not enabled".into(),
        };
        assert!(err.to_string().contains("claude"));

        let err = LlmError::TemplateMissing {
            variable: "foo".into(),
        };
        assert!(err.to_string().contains("foo"));

        let err = LlmError::RequestFailed {
            provider: "openai".into(),
            attempts: 3,
            last_error: "timeout".into(),
        };
        assert!(err.to_string().contains("3 attempt(s)"));

        let err = LlmError::ResponseParse {
            raw_response: "not json".into(),
            reason: "unexpected token".into(),
        };
        assert!(err.to_string().contains("unexpected token"));
    }

    #[test]
    fn template_vars_api() {
        let mut vars = LlmTemplateVars::new();
        vars.set("key1", serde_json::json!("value1"));
        vars.set("key2", serde_json::json!(42));

        assert!(vars.contains("key1"));
        assert!(!vars.contains("key3"));
        assert_eq!(vars.get("key1").unwrap(), &serde_json::json!("value1"));
        assert_eq!(vars.iter().count(), 2);
    }

    #[tokio::test]
    async fn claude_provider_returns_fault_without_feature() {
        // Claude has no built-in implementation, so the call errors and
        // surfaces as RequestFailed after the single attempt.
        let t = LlmTransition::new(LlmProvider::Claude);
        let mut bus = Bus::new();

        let outcome = t.run("test".to_string(), &(), &mut bus).await;
        assert!(matches!(
            outcome,
            Outcome::Fault(LlmError::RequestFailed { .. })
        ));
    }

    #[test]
    fn description_includes_model_and_params() {
        let t = LlmTransition::new(LlmProvider::Claude)
            .model("claude-sonnet-4-5-20250929")
            .max_tokens(200)
            .temperature(0.3);

        let desc = t.description().unwrap();
        assert!(desc.contains("claude"));
        assert!(desc.contains("claude-sonnet-4-5-20250929"));
        assert!(desc.contains("200"));
    }
}