Skip to main content

jamjet_worker/executors/
model_node.rs

//! Executor for `model` workflow nodes.
//!
//! Resolves the model configuration from the workflow IR (as copied into the
//! work item payload by the scheduler), calls the appropriate `ModelAdapter`
//! via `ModelRegistry`, and records GenAI telemetry.

6use crate::executor::{ExecutionResult, NodeExecutor};
7use async_trait::async_trait;
8use jamjet_models::{ChatMessage, ModelConfig, ModelRegistry, ModelRequest};
9use jamjet_state::backend::WorkItem;
10use serde_json::{json, Value};
11use std::sync::Arc;
12use tracing::{debug, instrument};
13
14/// Executor for `model` workflow nodes.
15pub struct ModelNodeExecutor {
16    registry: Arc<ModelRegistry>,
17}
18
19impl ModelNodeExecutor {
20    pub fn new(registry: Arc<ModelRegistry>) -> Self {
21        Self { registry }
22    }
23}
24
25#[async_trait]
26impl NodeExecutor for ModelNodeExecutor {
27    #[instrument(skip(self, item), fields(node_id = %item.node_id))]
28    async fn execute(&self, item: &WorkItem) -> Result<ExecutionResult, String> {
29        let start = std::time::Instant::now();
30
31        // Extract model config from the work item payload.
32        // The payload is populated by the scheduler from the IR node definition.
33        let model = item
34            .payload
35            .get("model")
36            .and_then(|v| v.as_str())
37            .unwrap_or("claude-sonnet-4-6")
38            .to_string();
39
40        let system_prompt = item
41            .payload
42            .get("system_prompt")
43            .and_then(|v| v.as_str())
44            .map(|s| s.to_string());
45
46        let max_tokens = item
47            .payload
48            .get("max_tokens")
49            .and_then(|v| v.as_u64())
50            .map(|n| n as u32);
51
52        let temperature = item
53            .payload
54            .get("temperature")
55            .and_then(|v| v.as_f64())
56            .map(|f| f as f32);
57
58        // Build the messages from state and payload.
59        // The `prompt` field in payload may reference workflow state via template strings.
60        let prompt = item
61            .payload
62            .get("prompt")
63            .and_then(|v| v.as_str())
64            .unwrap_or("")
65            .to_string();
66
67        let mut messages = Vec::new();
68        if !prompt.is_empty() {
69            messages.push(ChatMessage::user(prompt));
70        } else {
71            // Use any explicit messages array from payload.
72            if let Some(msgs) = item.payload.get("messages").and_then(|v| v.as_array()) {
73                for msg in msgs {
74                    let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("user");
75                    let content = msg
76                        .get("content")
77                        .and_then(|c| c.as_str())
78                        .unwrap_or("")
79                        .to_string();
80                    match role {
81                        "system" => messages.push(ChatMessage::system(content)),
82                        "assistant" => messages.push(ChatMessage::assistant(content)),
83                        _ => messages.push(ChatMessage::user(content)),
84                    }
85                }
86            }
87        }
88
89        if messages.is_empty() {
90            return Err("ModelNodeExecutor: no prompt or messages provided".into());
91        }
92
93        let config = ModelConfig {
94            model: Some(model.clone()),
95            max_tokens,
96            temperature,
97            system_prompt,
98            stop_sequences: None,
99        };
100
101        debug!(model = %model, messages = messages.len(), "Calling model");
102
103        let request = ModelRequest::new(messages).with_config(config);
104        let response = self
105            .registry
106            .chat(request)
107            .await
108            .map_err(|e| format!("Model call failed: {e}"))?;
109
110        let duration_ms = start.elapsed().as_millis() as u64;
111
112        // Build output — structured or plain text.
113        let output: Value = json!({
114            "content": response.content,
115            "model": response.model,
116            "finish_reason": response.finish_reason,
117        });
118
119        // State patch: write content to `last_model_output` in workflow state.
120        let state_patch = json!({
121            "last_model_output": response.content,
122        });
123
124        Ok(ExecutionResult {
125            output,
126            state_patch,
127            duration_ms,
128            gen_ai_system: Some(
129                // Infer system from model name prefix.
130                if response.model.starts_with("claude") {
131                    "anthropic"
132                } else if response.model.starts_with("gpt") || response.model.starts_with("o1") {
133                    "openai"
134                } else {
135                    "unknown"
136                }
137                .to_string(),
138            ),
139            gen_ai_model: Some(response.model),
140            input_tokens: Some(response.input_tokens),
141            output_tokens: Some(response.output_tokens),
142            finish_reason: Some(response.finish_reason),
143        })
144    }
145}