use serde_json::{json, Value};
/// Options accepted by [`build_chat_graph`].
///
/// Every `Option` field that is `None` is simply omitted from the generated graph.
#[derive(Clone, Debug, Default)]
pub struct ChatArgs {
    /// Emitted as the prompt op's `template` input, wrapped as `{"literal": ...}`.
    pub template: Option<Value>,
    /// Copied verbatim onto the llm op under the `resource` key.
    pub resource: Option<Value>,
    /// Copied onto the llm op under the `ratios` key.
    pub ratios: Option<Vec<f32>>,
    /// Copied onto the llm op under the `fallback` key.
    pub fallback: Option<Vec<String>>,
    /// Emitted as the llm op's `response_format` input, wrapped as `{"literal": ...}`.
    pub response_format: Option<Value>,
    /// Written to the llm op's `delay` key (units not established here — presumably
    /// seconds; TODO confirm against the engine).
    pub delay: f64,
    /// Graph name and `full_name` prefix for every op; an empty string means `"chat"`.
    pub name: String,
}
/// Options accepted by [`build_ask_graph`].
///
/// `fields` must be non-empty; every `Option` field that is `None` is omitted from
/// the generated graph.
#[derive(Clone, Debug, Default)]
pub struct AskArgs {
    /// Emitted as the prompt op's `template` input, wrapped as `{"literal": ...}`.
    pub template: Option<Value>,
    /// Copied verbatim onto the llm op under the `resource` key.
    pub resource: Option<Value>,
    /// Copied onto the llm op under the `ratios` key.
    pub ratios: Option<Vec<f32>>,
    /// Copied onto the llm op under the `fallback` key.
    pub fallback: Option<Vec<String>>,
    /// Field specs passed to the parser op's `extract` input; must not be empty.
    pub fields: Vec<String>,
    /// Parser op `format` input; an empty string means `"xml"`.
    pub parser: String,
    /// Written to the llm op's `delay` key (units not established here — presumably
    /// seconds; TODO confirm against the engine).
    pub delay: f64,
    /// Emitted as the parser op's `validators` input, wrapped as `{"literal": ...}`.
    pub validators: Option<Value>,
    /// Emitted as the llm op's `response_format` input, wrapped as `{"literal": ...}`.
    pub response_format: Option<Value>,
    /// Graph name and `full_name` prefix for every op; an empty string means `"ask"`.
    pub name: String,
}
/// Builds the two-op "chat" graph description: a `prompt` op feeding an `llm` op.
///
/// The graph is named `args.name`, or `"chat"` when that is empty, and the ops get
/// the `full_name`s `<graph>.prompt` and `<graph>.llm`. The llm op reads the prompt's
/// `messages` and maps `content`, `role`, `model_used`, `usage` and `extras` onto
/// `__PARENT__` outputs; the graph declares the same five outputs.
pub fn build_chat_graph(args: &ChatArgs) -> Value {
    // Outputs forwarded from the llm op to the enclosing graph, in emission order.
    const EXPORTED: [&str; 5] = ["content", "role", "model_used", "usage", "extras"];

    let name: &str = match args.name.as_str() {
        "" => "chat",
        other => other,
    };
    let prompt_id = format!("{}.prompt", name);
    let llm_id = format!("{}.llm", name);

    // The prompt op only receives a template when one was supplied.
    let prompt_inputs: serde_json::Map<String, Value> = args
        .template
        .iter()
        .map(|tpl| ("template".to_string(), json!({ "literal": tpl })))
        .collect();

    // `messages` is always wired from the prompt op; `response_format` is optional.
    let mut llm_inputs = serde_json::Map::new();
    llm_inputs.insert(
        "messages".to_string(),
        json!({ "required": true, "ref": { "source": prompt_id, "var": "messages" } }),
    );
    if let Some(format) = args.response_format.as_ref() {
        llm_inputs.insert("response_format".to_string(), json!({ "literal": format }));
    }

    let llm_outputs: serde_json::Map<String, Value> = EXPORTED
        .iter()
        .map(|&var| {
            (
                var.to_string(),
                json!({ "ref": { "source": "__PARENT__", "var": var, "is_output": true } }),
            )
        })
        .collect();
    let graph_outputs: serde_json::Map<String, Value> =
        EXPORTED.iter().map(|&var| (var.to_string(), json!({}))).collect();

    let mut llm_op = json!({
        "type": "llm",
        "name": "llm",
        "full_name": llm_id,
        "bound": "io",
        "delay": args.delay,
        "inputs": llm_inputs,
        "outputs": llm_outputs
    });
    attach_resource(&mut llm_op, &args.resource, &args.ratios, &args.fallback);

    json!({
        "schema_version": "1.0",
        "type": "graph",
        "name": name,
        "full_name": name,
        "entries": ["prompt"],
        "exits": ["llm"],
        "initial_ready_count": { "prompt": 0, "llm": 1 },
        "compiled_adj": {
            "prompt": [["llm", false]],
            "llm": []
        },
        "inputs": {},
        "outputs": graph_outputs,
        "ops": {
            "prompt": {
                "type": "prompt",
                "name": "prompt",
                "full_name": prompt_id,
                "bound": "sync",
                "inputs": prompt_inputs,
                "outputs": { "messages": {} }
            },
            "llm": llm_op
        }
    })
}
/// Builds the three-op "ask" graph description: `prompt` -> `llm` -> `parser`.
///
/// The graph is named `args.name`, or `"ask"` when that is empty, and the ops get
/// the `full_name`s `<graph>.prompt`, `<graph>.llm` and `<graph>.parser`. The parser
/// reads the llm's `content`, extracts `args.fields` using format `args.parser`
/// (`"xml"` when empty), and maps every result (`"*"`) onto `__PARENT__` outputs.
///
/// # Panics
///
/// Panics when `args.fields` is empty.
pub fn build_ask_graph(args: &AskArgs) -> Value {
    assert!(
        !args.fields.is_empty(),
        "build_ask_graph: `fields` is required — matches Python's TypeError"
    );

    let name: &str = match args.name.as_str() {
        "" => "ask",
        other => other,
    };
    let parser_format: &str = match args.parser.as_str() {
        "" => "xml",
        other => other,
    };
    let prompt_id = format!("{}.prompt", name);
    let llm_id = format!("{}.llm", name);
    let parser_id = format!("{}.parser", name);

    // Prompt op: only a template input, and only when one was supplied.
    let prompt_inputs: serde_json::Map<String, Value> = args
        .template
        .iter()
        .map(|tpl| ("template".to_string(), json!({ "literal": tpl })))
        .collect();

    // Llm op: consumes the prompt's messages and exposes `content` to the parser.
    let mut llm_inputs = serde_json::Map::new();
    llm_inputs.insert(
        "messages".to_string(),
        json!({ "required": true, "ref": { "source": prompt_id, "var": "messages" } }),
    );
    llm_inputs.extend(
        args.response_format
            .iter()
            .map(|fmt| ("response_format".to_string(), json!({ "literal": fmt }))),
    );
    let mut llm_op = json!({
        "type": "llm",
        "name": "llm",
        "full_name": llm_id,
        "bound": "io",
        "delay": args.delay,
        "inputs": llm_inputs,
        "outputs": { "content": {} }
    });
    attach_resource(&mut llm_op, &args.resource, &args.ratios, &args.fallback);

    // Parser op: parses the llm's `content`; insertion order is text, validators,
    // format, extract.
    let mut parser_inputs = serde_json::Map::new();
    parser_inputs.insert(
        "text".to_string(),
        json!({ "required": true, "ref": { "source": llm_id, "var": "content" } }),
    );
    parser_inputs.extend(
        args.validators
            .iter()
            .map(|rules| ("validators".to_string(), json!({ "literal": rules }))),
    );
    parser_inputs.insert("format".to_string(), json!({ "literal": parser_format }));
    parser_inputs.insert("extract".to_string(), json!({ "literal": args.fields }));
    let parser_op = json!({
        "type": "parser",
        "name": "parser",
        "full_name": parser_id,
        "bound": "sync",
        "inputs": parser_inputs,
        "outputs": { "*": { "ref": { "source": "__PARENT__", "var": "*", "is_output": true } } }
    });

    json!({
        "schema_version": "1.0",
        "type": "graph",
        "name": name,
        "full_name": name,
        "entries": ["prompt"],
        "exits": ["parser"],
        "initial_ready_count": { "prompt": 0, "llm": 1, "parser": 1 },
        "compiled_adj": {
            "prompt": [["llm", false]],
            "llm": [["parser", false]],
            "parser": []
        },
        "inputs": {},
        "outputs": {},
        "ops": {
            "prompt": {
                "type": "prompt",
                "name": "prompt",
                "full_name": prompt_id,
                "bound": "sync",
                "inputs": prompt_inputs,
                "outputs": { "messages": {} }
            },
            "llm": llm_op,
            "parser": parser_op
        }
    })
}
/// Copies the optional routing settings onto an op object, skipping any that are `None`.
///
/// `resource` is stored verbatim under `"resource"`, `ratios` under `"ratios"` and
/// `fallback` under `"fallback"`.
///
/// Each ratio is widened to `f64` through its shortest decimal form. This makes
/// `0.7_f32` serialize as `0.7` rather than as the widened binary value
/// (`0.699999988079071…`) that a plain `f32 as f64` conversion produces. The result
/// then matches what a caller passing ordinary decimal floats would expect.
/// Non-finite ratios still serialize as JSON `null`.
///
/// # Panics
///
/// Panics if `op` is not a JSON object; the callers in this file pass object literals.
fn attach_resource(
    op: &mut Value,
    resource: &Option<Value>,
    ratios: &Option<Vec<f32>>,
    fallback: &Option<Vec<String>>,
) {
    let obj = op.as_object_mut().expect("op is an object literal");
    if let Some(r) = resource {
        obj.insert("resource".into(), r.clone());
    }
    if let Some(rt) = ratios {
        // `f32`'s Display is the shortest string that round-trips, so parsing it as
        // `f64` yields the decimal value the caller wrote (e.g. 0.7, not 0.6999999880…).
        let widened: Vec<f64> = rt
            .iter()
            .map(|&r| r.to_string().parse::<f64>().unwrap_or_else(|_| f64::from(r)))
            .collect();
        obj.insert("ratios".into(), json!(widened));
    }
    if let Some(fb) = fallback {
        obj.insert("fallback".into(), json!(fb));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::engine::GraphEnvelope;

    /// A default chat graph parses back through the engine envelope with both ops.
    #[test]
    fn chat_graph_round_trips_through_envelope() {
        let g = build_chat_graph(&ChatArgs {
            template: Some(json!({"system": "ok", "user": "hi"})),
            resource: Some(json!("gpt-4o")),
            ..Default::default()
        });
        assert_eq!(g["ops"]["llm"]["resource"], "gpt-4o");
        let env = GraphEnvelope::parse(&g.to_string()).unwrap();
        assert_eq!(env.config.name, "chat");
        assert_eq!(env.config.ops.len(), 2);
        assert!(env.config.ops.contains_key("prompt"));
        assert!(env.config.ops.contains_key("llm"));
    }

    /// A non-empty `name` replaces the default and prefixes every op's `full_name`.
    #[test]
    fn chat_graph_uses_custom_name() {
        let g = build_chat_graph(&ChatArgs {
            name: "support".into(),
            ..Default::default()
        });
        assert_eq!(g["name"], "support");
        assert_eq!(g["ops"]["prompt"]["full_name"], "support.prompt");
        assert_eq!(g["ops"]["llm"]["full_name"], "support.llm");
        assert_eq!(
            g["ops"]["llm"]["inputs"]["messages"]["ref"]["source"],
            "support.prompt"
        );
    }

    /// `fields` is mandatory; `#[should_panic]` also pins the panic message.
    #[test]
    #[should_panic(expected = "`fields` is required")]
    fn ask_graph_requires_fields() {
        let _ = build_ask_graph(&AskArgs {
            template: Some(json!("q")),
            ..Default::default()
        });
    }

    /// A populated ask graph parses back through the engine envelope with three ops.
    #[test]
    fn ask_graph_round_trips() {
        let g = build_ask_graph(&AskArgs {
            template: Some(json!("Classify: {speech}")),
            resource: Some(json!("claude-haiku")),
            fields: vec!["result: str".into()],
            parser: "xml".into(),
            ..Default::default()
        });
        let env = GraphEnvelope::parse(&g.to_string()).unwrap();
        assert_eq!(env.config.ops.len(), 3);
        assert!(env.config.ops.contains_key("parser"));
    }

    /// An empty `parser` falls back to `"xml"`, and the parser is wired to the llm.
    #[test]
    fn ask_graph_defaults_parser_to_xml() {
        let g = build_ask_graph(&AskArgs {
            fields: vec!["result: str".into()],
            ..Default::default()
        });
        assert_eq!(g["name"], "ask");
        let inputs = &g["ops"]["parser"]["inputs"];
        assert_eq!(inputs["format"]["literal"], "xml");
        assert_eq!(inputs["extract"]["literal"], json!(["result: str"]));
        assert_eq!(inputs["text"]["ref"]["source"], "ask.llm");
    }
}