1use crate::model::{AgentCfg, QaNode};
2use anyhow::{Result, bail};
3use gsm_core::MessageEnvelope;
4use regex::Regex;
5use serde_json::{Value, json};
6
7pub async fn run_qa(
8 cfg: &QaNode,
9 env: &MessageEnvelope,
10 state: &mut Value,
11 _hbs: &handlebars::Handlebars<'static>,
12) -> Result<()> {
13 run_qa_inner(cfg, env, state, true).await
14}
15
16#[allow(dead_code)]
17pub async fn run_qa_offline(cfg: &QaNode, env: &MessageEnvelope, state: &mut Value) -> Result<()> {
18 run_qa_inner(cfg, env, state, false).await
19}
20
21async fn run_qa_inner(
22 cfg: &QaNode,
23 env: &MessageEnvelope,
24 state: &mut Value,
25 allow_agent: bool,
26) -> Result<()> {
27 if !state.is_object() {
28 *state = json!({});
29 }
30 let obj = state.as_object_mut().unwrap();
31
32 for q in &cfg.questions {
34 if let Some(def) = &q.default
35 && !obj.contains_key(&q.id)
36 {
37 obj.insert(q.id.clone(), def.clone());
38 }
39 }
40
41 let mut missing: Vec<&str> = cfg
42 .questions
43 .iter()
44 .filter(|q| !obj.contains_key(&q.id))
45 .map(|q| q.id.as_str())
46 .collect();
47
48 if !missing.is_empty()
49 && let Some(text) = &env.text
50 {
51 let number_re = Regex::new(r"(?P<n>\d+)").unwrap();
52 for q in &cfg.questions {
53 if missing.contains(&q.id.as_str()) {
54 match q.answer_type.as_deref() {
55 Some("number") => {
56 if let Some(caps) = number_re.captures(text)
57 && let Some(m) = caps.name("n")
58 {
59 obj.insert(q.id.clone(), json!(m.as_str().parse::<i64>().unwrap_or(1)));
60 }
61 }
62 _ => {
63 let loc = text
64 .split_whitespace()
65 .take(q.max_words.unwrap_or(3))
66 .collect::<Vec<_>>()
67 .join(" ");
68 if !loc.is_empty() {
69 obj.insert(q.id.clone(), json!(loc));
70 }
71 }
72 }
73 }
74 }
75 missing = cfg
76 .questions
77 .iter()
78 .filter(|q| !obj.contains_key(&q.id))
79 .map(|q| q.id.as_str())
80 .collect();
81 }
82
83 if !missing.is_empty()
84 && let Some(agent) = &cfg.fallback_agent
85 {
86 if !allow_agent {
87 bail!("qa fallback agent requires network; offline execution not permitted");
88 }
89 let extracted = call_agent(agent, env).await?;
90 for (k, v) in extracted.as_object().unwrap_or(&serde_json::Map::new()) {
91 obj.insert(k.clone(), v.clone());
92 }
93 }
94
95 for q in &cfg.questions {
97 if let Some(val) = obj.get(&q.id).cloned() {
98 if let Some(r) = q.validate.as_ref().and_then(|v| v.range) {
99 let n = val
100 .as_f64()
101 .or_else(|| val.as_i64().map(|x| x as f64))
102 .unwrap_or(0.0);
103 let clamped = n.clamp(r[0], r[1]);
104 obj.insert(q.id.clone(), json!(clamped));
105 }
106 if let Some(maxw) = q.max_words {
107 let s = val.as_str().unwrap_or_default();
108 if s.split_whitespace().count() > maxw {
109 bail!("answer '{}' exceeds max_words {}", q.id, maxw);
110 }
111 }
112 }
113 }
114 Ok(())
115}
116
117async fn call_agent(agent: &AgentCfg, env: &MessageEnvelope) -> Result<serde_json::Value> {
118 let url = agent
119 .endpoint
120 .clone()
121 .unwrap_or_else(|| "http://localhost:18080/agent/extract".into());
122 let body = json!({
123 "type": agent.r#type.as_deref().unwrap_or("ollama"),
124 "model": agent.model.as_deref().unwrap_or("gemma:instruct"),
125 "task": agent.task.as_deref().unwrap_or("extract keys"),
126 "text": env.text,
127 });
128 let resp = reqwest::Client::new().post(url).json(&body).send().await?;
129 let v: serde_json::Value = resp.json().await.unwrap_or_else(|_| json!({}));
130 Ok(v.get("extracted").cloned().unwrap_or_else(|| json!({})))
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{QaNode, Question, Validate};
    use handlebars::Handlebars;

    /// Builds a QA node with no welcome message and no fallback agent, so the
    /// tests never touch the network.
    fn qa_node(questions: Vec<Question>) -> QaNode {
        QaNode {
            welcome: None,
            questions,
            fallback_agent: None,
        }
    }

    fn envelope_with_text(text: Option<&str>) -> MessageEnvelope {
        MessageEnvelope {
            tenant: "acme".into(),
            platform: gsm_core::Platform::Slack,
            chat_id: "C1".into(),
            user_id: "U1".into(),
            thread_id: None,
            msg_id: "msg-1".into(),
            text: text.map(|t| t.to_string()),
            timestamp: "2024-01-01T00:00:00Z".into(),
            context: Default::default(),
        }
    }

    /// A free-text question with no range validation.
    fn text_question(
        id: &str,
        max_words: Option<usize>,
        default: Option<serde_json::Value>,
    ) -> Question {
        Question {
            id: id.into(),
            prompt: id.into(),
            answer_type: None,
            max_words,
            default,
            validate: None,
        }
    }

    fn number_question() -> Question {
        Question {
            id: "quantity".into(),
            prompt: "How many?".into(),
            answer_type: Some("number".into()),
            max_words: None,
            default: None,
            validate: Some(Validate {
                range: Some([1.0, 10.0]),
            }),
        }
    }

    #[tokio::test]
    async fn run_qa_populates_defaults_and_text() {
        let qa = qa_node(vec![
            text_question("name", Some(3), Some(json!("guest"))),
            number_question(),
        ]);
        let env = envelope_with_text(Some("I need 4 adapters"));
        let mut state = serde_json::Value::Null;
        // `run_qa` only borrows the registry, so a local avoids leaking one.
        let hbs = Handlebars::new();
        run_qa(&qa, &env, &mut state, &hbs).await.unwrap();

        let obj = state.as_object().expect("state object");
        assert_eq!(obj.get("name"), Some(&json!("guest")));
        assert_eq!(obj.get("quantity"), Some(&json!(4.0)));
    }

    #[tokio::test]
    async fn run_qa_errors_when_max_words_exceeded() {
        let qa = qa_node(vec![Question {
            prompt: "Describe".into(),
            ..text_question("desc", Some(1), None)
        }]);
        let env = envelope_with_text(None);
        let mut state = json!({"desc": "too many words"});
        let hbs = Handlebars::new();
        let err = run_qa(&qa, &env, &mut state, &hbs).await.unwrap_err();
        assert!(err.to_string().contains("max_words"));
    }

    #[tokio::test]
    async fn run_qa_keeps_existing_answers() {
        let qa = qa_node(vec![text_question("name", Some(3), Some(json!("guest")))]);
        let env = envelope_with_text(Some("Bob is here today"));
        let mut state = json!({"name": "Alice"});
        let hbs = Handlebars::new();
        run_qa(&qa, &env, &mut state, &hbs).await.unwrap();
        assert_eq!(state["name"], json!("Alice"));
    }

    #[tokio::test]
    async fn run_qa_leaves_number_unanswered_without_digits() {
        let qa = qa_node(vec![number_question()]);
        let env = envelope_with_text(Some("a few please"));
        let mut state = serde_json::Value::Null;
        let hbs = Handlebars::new();
        run_qa(&qa, &env, &mut state, &hbs).await.unwrap();
        let obj = state.as_object().expect("state object");
        assert!(obj.get("quantity").is_none());
    }
}