gsm-runner 0.4.43

Greentic messaging package: runner.
Documentation
use crate::model::{AgentCfg, QaNode};
use anyhow::{Result, bail};
use gsm_core::MessageEnvelope;
use regex::Regex;
use serde_json::{Value, json};

/// Runs the Q&A node for `env`, filling and validating answers in `state`.
///
/// This is the online entry point: it delegates to `run_qa_inner` with the
/// fallback agent permitted. The Handlebars registry is accepted but not used
/// by this function.
///
/// # Errors
///
/// Propagates any error returned by `run_qa_inner`.
pub async fn run_qa(
    cfg: &QaNode,
    env: &MessageEnvelope,
    state: &mut Value,
    _hbs: &handlebars::Handlebars<'static>,
) -> Result<()> {
    let allow_agent = true;
    let outcome = run_qa_inner(cfg, env, state, allow_agent).await;
    outcome
}

/// Runs the Q&A node for `env` without permitting the fallback agent.
///
/// Delegates to `run_qa_inner` with `allow_agent = false`, so when answers are
/// still missing and a fallback agent is configured the call fails with an
/// error instead of contacting the agent.
///
/// # Errors
///
/// Propagates any error returned by `run_qa_inner`.
// NOTE(review): nothing in this file calls this function, hence the lint
// allowance; presumably it is kept for offline callers — confirm.
#[allow(dead_code)]
pub async fn run_qa_offline(cfg: &QaNode, env: &MessageEnvelope, state: &mut Value) -> Result<()> {
    let allow_agent = false;
    let outcome = run_qa_inner(cfg, env, state, allow_agent).await;
    outcome
}

async fn run_qa_inner(
    cfg: &QaNode,
    env: &MessageEnvelope,
    state: &mut Value,
    allow_agent: bool,
) -> Result<()> {
    if !state.is_object() {
        *state = json!({});
    }
    let obj = state.as_object_mut().unwrap();

    // Defaults
    for q in &cfg.questions {
        if let Some(def) = &q.default
            && !obj.contains_key(&q.id)
        {
            obj.insert(q.id.clone(), def.clone());
        }
    }

    let mut missing: Vec<&str> = cfg
        .questions
        .iter()
        .filter(|q| !obj.contains_key(&q.id))
        .map(|q| q.id.as_str())
        .collect();

    if !missing.is_empty()
        && let Some(text) = &env.text
    {
        let number_re = Regex::new(r"(?P<n>\d+)").unwrap();
        for q in &cfg.questions {
            if missing.contains(&q.id.as_str()) {
                match q.answer_type.as_deref() {
                    Some("number") => {
                        if let Some(caps) = number_re.captures(text)
                            && let Some(m) = caps.name("n")
                        {
                            obj.insert(q.id.clone(), json!(m.as_str().parse::<i64>().unwrap_or(1)));
                        }
                    }
                    _ => {
                        let loc = text
                            .split_whitespace()
                            .take(q.max_words.unwrap_or(3))
                            .collect::<Vec<_>>()
                            .join(" ");
                        if !loc.is_empty() {
                            obj.insert(q.id.clone(), json!(loc));
                        }
                    }
                }
            }
        }
        missing = cfg
            .questions
            .iter()
            .filter(|q| !obj.contains_key(&q.id))
            .map(|q| q.id.as_str())
            .collect();
    }

    if !missing.is_empty()
        && let Some(agent) = &cfg.fallback_agent
    {
        if !allow_agent {
            bail!("qa fallback agent requires network; offline execution not permitted");
        }
        let extracted = call_agent(agent, env).await?;
        for (k, v) in extracted.as_object().unwrap_or(&serde_json::Map::new()) {
            obj.insert(k.clone(), v.clone());
        }
    }

    // Validate
    for q in &cfg.questions {
        if let Some(val) = obj.get(&q.id).cloned() {
            if let Some(r) = q.validate.as_ref().and_then(|v| v.range) {
                let n = val
                    .as_f64()
                    .or_else(|| val.as_i64().map(|x| x as f64))
                    .unwrap_or(0.0);
                let clamped = n.clamp(r[0], r[1]);
                obj.insert(q.id.clone(), json!(clamped));
            }
            if let Some(maxw) = q.max_words {
                let s = val.as_str().unwrap_or_default();
                if s.split_whitespace().count() > maxw {
                    bail!("answer '{}' exceeds max_words {}", q.id, maxw);
                }
            }
        }
    }
    Ok(())
}

/// Sends the message text to the extraction agent and returns its `extracted` payload.
///
/// The request is a JSON POST to `agent.endpoint` (default
/// `http://localhost:18080/agent/extract`) carrying the agent `type`, `model` and `task`
/// (defaults `ollama`, `gemma:instruct`, `extract keys`) together with `env.text`.
///
/// A response body that cannot be decoded as JSON, or that has no `extracted` field,
/// yields an empty JSON object rather than an error.
///
/// # Errors
///
/// Returns an error only when sending the request fails.
async fn call_agent(agent: &AgentCfg, env: &MessageEnvelope) -> Result<serde_json::Value> {
    let endpoint = agent
        .endpoint
        .as_deref()
        .unwrap_or("http://localhost:18080/agent/extract");
    let kind = agent.r#type.as_deref().unwrap_or("ollama");
    let model = agent.model.as_deref().unwrap_or("gemma:instruct");
    let task = agent.task.as_deref().unwrap_or("extract keys");

    let payload = json!({
      "type": kind,
      "model": model,
      "task": task,
      "text": env.text,
    });

    let client = reqwest::Client::new();
    let response = client.post(endpoint).json(&payload).send().await?;

    // Best effort: an undecodable body is treated like an empty reply.
    let reply = match response.json::<serde_json::Value>().await {
        Ok(parsed) => parsed,
        Err(_) => json!({}),
    };

    let extracted = match reply.get("extracted") {
        Some(found) => found.clone(),
        None => json!({}),
    };
    Ok(extracted)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{QaNode, Question, Validate};
    use handlebars::Handlebars;

    /// Builds a fresh registry for one test; `run_qa` only borrows it for the call,
    /// so nothing needs to be leaked to obtain a reference.
    fn handlebars() -> Handlebars<'static> {
        Handlebars::new()
    }

    /// Builds a Slack envelope with fixed identifiers and the given message text.
    fn envelope_with_text(text: Option<&str>) -> MessageEnvelope {
        MessageEnvelope {
            tenant: "acme".into(),
            platform: gsm_core::Platform::Slack,
            chat_id: "C1".into(),
            user_id: "U1".into(),
            thread_id: None,
            msg_id: "msg-1".into(),
            text: text.map(|t| t.to_string()),
            timestamp: "2024-01-01T00:00:00Z".into(),
            context: Default::default(),
        }
    }

    /// A numeric question without a default, limited to the range 1..=10.
    fn number_question() -> Question {
        Question {
            id: "quantity".into(),
            prompt: "How many?".into(),
            answer_type: Some("number".into()),
            max_words: None,
            default: None,
            validate: Some(Validate {
                range: Some([1.0, 10.0]),
            }),
        }
    }

    /// A Q&A node with the given questions and no welcome text or fallback agent.
    fn qa_with(questions: Vec<Question>) -> QaNode {
        QaNode {
            welcome: None,
            questions,
            fallback_agent: None,
        }
    }

    #[tokio::test]
    async fn run_qa_populates_defaults_and_text() {
        let qa = qa_with(vec![
            Question {
                id: "name".into(),
                prompt: "Name".into(),
                answer_type: None,
                max_words: Some(3),
                default: Some(json!("guest")),
                validate: None,
            },
            number_question(),
        ]);

        let env = envelope_with_text(Some("I need 4 adapters"));
        let mut state = serde_json::Value::Null;
        run_qa(&qa, &env, &mut state, &handlebars()).await.unwrap();

        let obj = state.as_object().expect("state object");
        assert_eq!(obj.get("name"), Some(&json!("guest")));
        // Range validation stores the clamped value back as a float.
        assert_eq!(obj.get("quantity"), Some(&json!(4.0)));
    }

    #[tokio::test]
    async fn run_qa_errors_when_max_words_exceeded() {
        let qa = qa_with(vec![Question {
            id: "desc".into(),
            prompt: "Describe".into(),
            answer_type: None,
            max_words: Some(1),
            default: None,
            validate: None,
        }]);

        let env = envelope_with_text(None);
        let mut state = json!({"desc": "too many words"});
        let err = run_qa(&qa, &env, &mut state, &handlebars())
            .await
            .unwrap_err();
        assert!(err.to_string().contains("max_words"));
    }

    #[tokio::test]
    async fn run_qa_clamps_number_to_range() {
        let qa = qa_with(vec![number_question()]);

        let env = envelope_with_text(Some("send 40 units"));
        let mut state = serde_json::Value::Null;
        run_qa(&qa, &env, &mut state, &handlebars()).await.unwrap();

        let obj = state.as_object().expect("state object");
        assert_eq!(obj.get("quantity"), Some(&json!(10.0)));
    }

    #[tokio::test]
    async fn run_qa_takes_leading_words_for_free_text() {
        let qa = qa_with(vec![Question {
            id: "desc".into(),
            prompt: "Describe".into(),
            answer_type: None,
            max_words: Some(2),
            default: None,
            validate: None,
        }]);

        let env = envelope_with_text(Some("red sports car please"));
        let mut state = serde_json::Value::Null;
        run_qa(&qa, &env, &mut state, &handlebars()).await.unwrap();

        let obj = state.as_object().expect("state object");
        assert_eq!(obj.get("desc"), Some(&json!("red sports")));
    }
}