gsm_runner/
qa_node.rs

1use crate::model::{AgentCfg, QaNode};
2use anyhow::{Result, bail};
3use gsm_core::MessageEnvelope;
4use regex::Regex;
5use serde_json::{Value, json};
6
7pub async fn run_qa(
8    cfg: &QaNode,
9    env: &MessageEnvelope,
10    state: &mut Value,
11    _hbs: &handlebars::Handlebars<'static>,
12) -> Result<()> {
13    run_qa_inner(cfg, env, state, true).await
14}
15
16#[allow(dead_code)]
17pub async fn run_qa_offline(cfg: &QaNode, env: &MessageEnvelope, state: &mut Value) -> Result<()> {
18    run_qa_inner(cfg, env, state, false).await
19}
20
21async fn run_qa_inner(
22    cfg: &QaNode,
23    env: &MessageEnvelope,
24    state: &mut Value,
25    allow_agent: bool,
26) -> Result<()> {
27    if !state.is_object() {
28        *state = json!({});
29    }
30    let obj = state.as_object_mut().unwrap();
31
32    // Defaults
33    for q in &cfg.questions {
34        if let Some(def) = &q.default
35            && !obj.contains_key(&q.id)
36        {
37            obj.insert(q.id.clone(), def.clone());
38        }
39    }
40
41    let mut missing: Vec<&str> = cfg
42        .questions
43        .iter()
44        .filter(|q| !obj.contains_key(&q.id))
45        .map(|q| q.id.as_str())
46        .collect();
47
48    if !missing.is_empty()
49        && let Some(text) = &env.text
50    {
51        let number_re = Regex::new(r"(?P<n>\d+)").unwrap();
52        for q in &cfg.questions {
53            if missing.contains(&q.id.as_str()) {
54                match q.answer_type.as_deref() {
55                    Some("number") => {
56                        if let Some(caps) = number_re.captures(text)
57                            && let Some(m) = caps.name("n")
58                        {
59                            obj.insert(q.id.clone(), json!(m.as_str().parse::<i64>().unwrap_or(1)));
60                        }
61                    }
62                    _ => {
63                        let loc = text
64                            .split_whitespace()
65                            .take(q.max_words.unwrap_or(3))
66                            .collect::<Vec<_>>()
67                            .join(" ");
68                        if !loc.is_empty() {
69                            obj.insert(q.id.clone(), json!(loc));
70                        }
71                    }
72                }
73            }
74        }
75        missing = cfg
76            .questions
77            .iter()
78            .filter(|q| !obj.contains_key(&q.id))
79            .map(|q| q.id.as_str())
80            .collect();
81    }
82
83    if !missing.is_empty()
84        && let Some(agent) = &cfg.fallback_agent
85    {
86        if !allow_agent {
87            bail!("qa fallback agent requires network; offline execution not permitted");
88        }
89        let extracted = call_agent(agent, env).await?;
90        for (k, v) in extracted.as_object().unwrap_or(&serde_json::Map::new()) {
91            obj.insert(k.clone(), v.clone());
92        }
93    }
94
95    // Validate
96    for q in &cfg.questions {
97        if let Some(val) = obj.get(&q.id).cloned() {
98            if let Some(r) = q.validate.as_ref().and_then(|v| v.range) {
99                let n = val
100                    .as_f64()
101                    .or_else(|| val.as_i64().map(|x| x as f64))
102                    .unwrap_or(0.0);
103                let clamped = n.clamp(r[0], r[1]);
104                obj.insert(q.id.clone(), json!(clamped));
105            }
106            if let Some(maxw) = q.max_words {
107                let s = val.as_str().unwrap_or_default();
108                if s.split_whitespace().count() > maxw {
109                    bail!("answer '{}' exceeds max_words {}", q.id, maxw);
110                }
111            }
112        }
113    }
114    Ok(())
115}
116
117async fn call_agent(agent: &AgentCfg, env: &MessageEnvelope) -> Result<serde_json::Value> {
118    let url = agent
119        .endpoint
120        .clone()
121        .unwrap_or_else(|| "http://localhost:18080/agent/extract".into());
122    let body = json!({
123      "type": agent.r#type.as_deref().unwrap_or("ollama"),
124      "model": agent.model.as_deref().unwrap_or("gemma:instruct"),
125      "task": agent.task.as_deref().unwrap_or("extract keys"),
126      "text": env.text,
127    });
128    let resp = reqwest::Client::new().post(url).json(&body).send().await?;
129    let v: serde_json::Value = resp.json().await.unwrap_or_else(|_| json!({}));
130    Ok(v.get("extracted").cloned().unwrap_or_else(|| json!({})))
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{QaNode, Question, Validate};
    use handlebars::Handlebars;

    /// Leaks a fresh registry to obtain a `&'static Handlebars<'static>` for `run_qa`.
    fn handlebars() -> &'static Handlebars<'static> {
        let registry = Box::new(Handlebars::new());
        Box::leak(registry)
    }

    /// Builds a Slack envelope with fixed ids and the given optional text.
    fn envelope_with_text(text: Option<&str>) -> MessageEnvelope {
        let text = text.map(str::to_string);
        MessageEnvelope {
            tenant: "acme".into(),
            platform: gsm_core::Platform::Slack,
            chat_id: "C1".into(),
            user_id: "U1".into(),
            thread_id: None,
            msg_id: "msg-1".into(),
            text,
            timestamp: "2024-01-01T00:00:00Z".into(),
            context: Default::default(),
        }
    }

    /// A numeric `quantity` question restricted to the range 1..=10.
    fn number_question() -> Question {
        let validate = Validate {
            range: Some([1.0, 10.0]),
        };
        Question {
            id: "quantity".into(),
            prompt: "How many?".into(),
            answer_type: Some("number".into()),
            max_words: None,
            default: None,
            validate: Some(validate),
        }
    }

    #[tokio::test]
    async fn run_qa_populates_defaults_and_text() {
        // `name` is answered by its default, `quantity` by the digits in the text.
        let name_question = Question {
            id: "name".into(),
            prompt: "Name".into(),
            answer_type: None,
            max_words: Some(3),
            default: Some(json!("guest")),
            validate: None,
        };
        let qa = QaNode {
            welcome: None,
            questions: vec![name_question, number_question()],
            fallback_agent: None,
        };

        let env = envelope_with_text(Some("I need 4 adapters"));
        let mut state = serde_json::Value::Null;
        run_qa(&qa, &env, &mut state, handlebars()).await.unwrap();

        let answers = state.as_object().expect("state object");
        assert_eq!(answers.get("name"), Some(&json!("guest")));
        // Range validation stores the clamped value as a float.
        assert_eq!(answers.get("quantity"), Some(&json!(4.0)));
    }

    #[tokio::test]
    async fn run_qa_errors_when_max_words_exceeded() {
        let desc_question = Question {
            id: "desc".into(),
            prompt: "Describe".into(),
            answer_type: None,
            max_words: Some(1),
            default: None,
            validate: None,
        };
        let qa = QaNode {
            welcome: None,
            questions: vec![desc_question],
            fallback_agent: None,
        };

        // The pre-filled answer has three words against a limit of one.
        let env = envelope_with_text(None);
        let mut state = json!({"desc": "too many words"});
        let outcome = run_qa(&qa, &env, &mut state, handlebars()).await;

        let err = outcome.unwrap_err();
        assert!(err.to_string().contains("max_words"));
    }
}