use anyhow::{Result, anyhow};
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use crate::provider::{LlmRequest, Message, MessageContent, Role};
use crate::provider::registry::ProviderRegistry;
/// A single planned search: cleaned keywords plus the intent that routes it.
///
/// Serialized with the field names as-is, i.e. `{"q": "...", "intent": {...}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubQuery {
    /// Cleaned search keywords (JSON key `q`).
    pub q: String,
    /// Structured intent for this sub-query (JSON key `intent`).
    pub intent: Intent,
}
/// Structured intent attached to each [`SubQuery`].
///
/// Internally tagged by `kind` with snake_case variant names, so it
/// (de)serializes as e.g. `{"kind":"weather","location":"Bangkok"}`. The
/// variants and their fields mirror the intent table in the planner system
/// prompt.
///
/// Fields that the prompt allows to be empty (explicitly, or via its
/// examples) carry `#[serde(default)]`. If the model omits such a field, it
/// deserializes to `""` instead of failing the whole plan.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum Intent {
    Weather { location: String },
    Currency { from: String, to: String },
    Timezone { location: String },
    Wikipedia { topic: String },
    GithubRepo { owner: String, repo: String },
    Flight {
        from: String,
        to: String,
        // Prompt: "YYYY-MM-DD or empty".
        #[serde(default)]
        date: String,
        trip: String,
    },
    Train {
        from: String,
        to: String,
        // Prompt: "YYYY-MM-DD or empty".
        #[serde(default)]
        date: String,
    },
    Hotel {
        city: String,
        // Prompt: "YYYY-MM-DD or empty".
        #[serde(default)]
        checkin: String,
    },
    Movie { query: String },
    Concert { query: String },
    Restaurant {
        query: String,
        // The prompt's "附近好吃的火锅" example emits an empty city.
        #[serde(default)]
        city: String,
    },
    Shopping { query: String },
    Stock { query: String },
    Express { number: String },
    News { query: String },
    Map { query: String },
    Translate { text: String, to: String },
    CryptoPrice { coin: String },
    Calendar { query: String },
    UnitConvert { query: String },
    Math { expr: String },
    IpLookup {
        // Prompt: "IP address or empty for self".
        #[serde(default)]
        ip: String,
    },
    DnsLookup { domain: String },
    Whois { domain: String },
    Phone { number: String },
    Idiom { query: String },
    Poem { query: String },
    Law { query: String },
    Hospital { query: String },
    Recipe { query: String },
    Sports { query: String },
    Lottery { query: String },
    Academic { query: String },
    Job {
        query: String,
        // Prompt: "city or empty".
        #[serde(default)]
        city: String,
    },
    Video { query: String },
    Book { query: String },
    Package { query: String, registry: String },
    Forum { query: String },
    /// Catch-all intent with no extra fields (the prompt's "if unsure" choice).
    ///
    /// `#[serde(other)]` also maps any unrecognized `kind` tag here. The
    /// rest of the plan survives instead of being rejected wholesale.
    #[serde(other)]
    General,
}
/// The planner's answer: the list of sub-queries to run for one user query.
///
/// Matches the JSON shape requested from the LLM: `{"sub_queries": [...]}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryPlan {
    /// Sub-queries in the order the planner emitted them.
    pub sub_queries: Vec<SubQuery>,
}
impl QueryPlan {
    /// Build a one-entry plan that forwards `query` unchanged as a
    /// general-intent search.
    ///
    /// This is the fallback the planner returns when it cannot produce a plan.
    pub fn passthrough(query: &str) -> Self {
        let only = SubQuery {
            q: String::from(query),
            intent: Intent::General,
        };
        Self {
            sub_queries: vec![only],
        }
    }
}
/// Build the system prompt for the query planner LLM.
///
/// The prompt contains, in order:
/// - the current local time (the rules rely on it: keywords for live-data
///   queries must not repeat dates/years);
/// - the JSON output schema;
/// - the per-intent field table;
/// - the planning rules;
/// - few-shot examples.
///
/// The `kind` names listed here must match the snake_case serde tags of
/// `Intent`; otherwise the model's reply will not map to the intended variant.
/// The "Max 5 sub_queries" rule is also enforced in code after parsing.
///
/// NOTE(review): the flight example leaves `date` empty for the relative date
/// "下周三" (next Wednesday), while the rules say to use YYYY-MM-DD whenever the
/// user specifies a date. Confirm which behavior is intended.
fn planner_system() -> String {
let now = chrono::Local::now();
// Zone label for the prompt header. NOTE(review): for `chrono::Local`, `%Z`
// may render a numeric offset rather than a zone name. Confirm.
let tz = now.format("%Z").to_string();
// e.g. "2025-01-15 09:30 Wednesday"; the weekday helps resolve relative dates.
let ts = now.format("%Y-%m-%d %H:%M %A").to_string();
// Inside this format string, `{{` / `}}` are escaped literal braces for the
// JSON text. `{ts}` and `{tz}` are the only interpolated values.
format!(
r#"Current time: {ts} ({tz})
You analyze a user search query and decide how to answer it.
Output ONLY valid JSON matching this schema (no prose, no markdown, no code fences):
{{"sub_queries":[{{"q":"<cleaned keywords>","intent":{{"kind":"<intent>", ...fields}}}}]}}
Intent kinds and required fields:
weather : {{"kind":"weather","location":"<city in English>"}}
currency : {{"kind":"currency","from":"<ISO code>","to":"<ISO code>"}}
timezone : {{"kind":"timezone","location":"<IANA zone ONLY, e.g. Asia/Shanghai>"}}
wikipedia : {{"kind":"wikipedia","topic":"<topic phrase>"}}
github_repo : {{"kind":"github_repo","owner":"<owner>","repo":"<name>"}}
flight : {{"kind":"flight","from":"<city>","to":"<city>","date":"<YYYY-MM-DD or empty>","trip":"oneway|roundtrip"}}
train : {{"kind":"train","from":"<city>","to":"<city>","date":"<YYYY-MM-DD or empty>"}}
hotel : {{"kind":"hotel","city":"<city>","checkin":"<YYYY-MM-DD or empty>"}}
movie : {{"kind":"movie","query":"<movie name or keyword>"}}
concert : {{"kind":"concert","query":"<artist or show name>"}}
restaurant : {{"kind":"restaurant","query":"<cuisine or keyword>","city":"<city>"}}
shopping : {{"kind":"shopping","query":"<product name>"}}
stock : {{"kind":"stock","query":"<stock name or code>"}}
express : {{"kind":"express","number":"<tracking number>"}}
news : {{"kind":"news","query":"<topic>"}}
map : {{"kind":"map","query":"<place or route>"}}
translate : {{"kind":"translate","text":"<text to translate>","to":"<target language code>"}}
crypto_price : {{"kind":"crypto_price","coin":"<coin id, e.g. bitcoin>"}}
calendar : {{"kind":"calendar","query":"<date question>"}}
unit_convert : {{"kind":"unit_convert","query":"<conversion expression>"}}
math : {{"kind":"math","expr":"<math expression, e.g. 123*456>"}}
ip_lookup : {{"kind":"ip_lookup","ip":"<IP address or empty for self>"}}
dns_lookup : {{"kind":"dns_lookup","domain":"<domain name>"}}
whois : {{"kind":"whois","domain":"<domain name>"}}
phone : {{"kind":"phone","number":"<phone number>"}}
idiom : {{"kind":"idiom","query":"<idiom or word>"}}
poem : {{"kind":"poem","query":"<poem title or keyword>"}}
law : {{"kind":"law","query":"<law question>"}}
hospital : {{"kind":"hospital","query":"<medical question>"}}
recipe : {{"kind":"recipe","query":"<dish name>"}}
sports : {{"kind":"sports","query":"<match or team>"}}
lottery : {{"kind":"lottery","query":"<lottery type>"}}
academic : {{"kind":"academic","query":"<paper topic or keyword>"}}
job : {{"kind":"job","query":"<job title or keyword>","city":"<city or empty>"}}
video : {{"kind":"video","query":"<video topic>"}}
book : {{"kind":"book","query":"<book title or topic>"}}
package : {{"kind":"package","query":"<package name>","registry":"npm|pypi|crates"}}
forum : {{"kind":"forum","query":"<discussion topic>"}}
general : {{"kind":"general"}}
Rules:
- SPLIT multi-entity queries: if the query asks about N cities/entities,
output N sub_queries.
- CLEAN keywords: drop filler words and dates. The current time is already
known, so never include dates/years in the "q" field for live-data queries
(weather, currency, stock, etc.).
- PREFER English city names for "weather" intent so API lookups hit.
- KEEP Chinese keywords for Chinese sites: movie, restaurant, concert, shopping,
stock, express, news, recipe, poem, idiom, sports, job, video, book, forum, law.
- For date fields: use YYYY-MM-DD if user specifies a date, empty string if not.
- If unsure of intent, use "general" — don't force a wrong match.
- Max 5 sub_queries. Never output more.
Examples:
Input: "曼谷、广州、武汉未来7天的天气"
Output: {{"sub_queries":[
{{"q":"Bangkok weather","intent":{{"kind":"weather","location":"Bangkok"}}}},
{{"q":"Guangzhou weather","intent":{{"kind":"weather","location":"Guangzhou"}}}},
{{"q":"Wuhan weather","intent":{{"kind":"weather","location":"Wuhan"}}}}
]}}
Input: "美元兑人民币汇率"
Output: {{"sub_queries":[
{{"q":"USD to CNY","intent":{{"kind":"currency","from":"USD","to":"CNY"}}}}
]}}
Input: "rust async fn 用法"
Output: {{"sub_queries":[
{{"q":"rust async fn usage","intent":{{"kind":"general"}}}}
]}}
Input: "tokio 仓库什么情况"
Output: {{"sub_queries":[
{{"q":"tokio-rs/tokio","intent":{{"kind":"github_repo","owner":"tokio-rs","repo":"tokio"}}}}
]}}
Input: "下周三北京飞曼谷的机票"
Output: {{"sub_queries":[
{{"q":"北京飞曼谷机票","intent":{{"kind":"flight","from":"北京","to":"曼谷","date":"","trip":"oneway"}}}}
]}}
Input: "茅台股价"
Output: {{"sub_queries":[
{{"q":"茅台股票","intent":{{"kind":"stock","query":"茅台"}}}}
]}}
Input: "顺丰 SF1234567890 到哪了"
Output: {{"sub_queries":[
{{"q":"SF1234567890","intent":{{"kind":"express","number":"SF1234567890"}}}}
]}}
Input: "比特币现在多少钱"
Output: {{"sub_queries":[
{{"q":"bitcoin price","intent":{{"kind":"crypto_price","coin":"bitcoin"}}}}
]}}
Input: "附近好吃的火锅"
Output: {{"sub_queries":[
{{"q":"火锅推荐","intent":{{"kind":"restaurant","query":"火锅","city":""}}}}
]}}
Input: "iPhone 16 多少钱"
Output: {{"sub_queries":[
{{"q":"iPhone 16 价格","intent":{{"kind":"shopping","query":"iPhone 16"}}}}
]}}
Input: "周杰伦演唱会门票"
Output: {{"sub_queries":[
{{"q":"周杰伦演唱会","intent":{{"kind":"concert","query":"周杰伦"}}}}
]}}
Input: "翻译 hello world 成中文"
Output: {{"sub_queries":[
{{"q":"hello world","intent":{{"kind":"translate","text":"hello world","to":"zh"}}}}
]}}"#)
}
pub async fn plan(
query: &str,
flash_model: &str,
providers: &ProviderRegistry,
) -> QueryPlan {
let fut = try_plan(query, flash_model, providers);
let result = tokio::time::timeout(std::time::Duration::from_secs(5), fut).await;
match result {
Ok(Ok(p)) => {
let intents: Vec<&str> = p.sub_queries.iter().map(|s| match &s.intent {
Intent::Weather { .. } => "weather",
Intent::Currency { .. } => "currency",
Intent::Timezone { .. } => "timezone",
Intent::Wikipedia { .. } => "wikipedia",
Intent::GithubRepo { .. } => "github_repo",
Intent::Flight { .. } => "flight",
Intent::Train { .. } => "train",
Intent::Hotel { .. } => "hotel",
Intent::Movie { .. } => "movie",
Intent::Concert { .. } => "concert",
Intent::Restaurant { .. } => "restaurant",
Intent::Shopping { .. } => "shopping",
Intent::Stock { .. } => "stock",
Intent::Express { .. } => "express",
Intent::News { .. } => "news",
Intent::Map { .. } => "map",
Intent::Translate { .. } => "translate",
Intent::CryptoPrice { .. } => "crypto_price",
Intent::Calendar { .. } => "calendar",
Intent::UnitConvert { .. } => "unit_convert",
Intent::Math { .. } => "math",
Intent::IpLookup { .. } => "ip_lookup",
Intent::DnsLookup { .. } => "dns_lookup",
Intent::Whois { .. } => "whois",
Intent::Phone { .. } => "phone",
Intent::Idiom { .. } => "idiom",
Intent::Poem { .. } => "poem",
Intent::Law { .. } => "law",
Intent::Hospital { .. } => "hospital",
Intent::Recipe { .. } => "recipe",
Intent::Sports { .. } => "sports",
Intent::Lottery { .. } => "lottery",
Intent::Academic { .. } => "academic",
Intent::Job { .. } => "job",
Intent::Video { .. } => "video",
Intent::Book { .. } => "book",
Intent::Package { .. } => "package",
Intent::Forum { .. } => "forum",
Intent::General => "general",
}).collect();
tracing::info!(
query = %query,
sub_count = p.sub_queries.len(),
intents = ?intents,
"query_planner: planned"
);
p
}
Ok(Err(e)) => {
tracing::warn!(
query = %query,
error = %e,
"query_planner: fallback to passthrough"
);
QueryPlan::passthrough(query)
}
Err(_) => {
tracing::warn!(query = %query, "query_planner: 5s timeout, fallback to passthrough");
QueryPlan::passthrough(query)
}
}
}
/// Ask the flash model for a JSON [`QueryPlan`] and parse it.
///
/// The model's streamed text is buffered. The first balanced JSON object is
/// then extracted (tolerating code fences or surrounding prose) and
/// deserialized. The resulting plan is capped at `MAX_SUB_QUERIES` entries.
///
/// # Errors
///
/// Returns an error if any of the following happens:
/// - the flash provider cannot be resolved;
/// - the stream fails or yields `StreamEvent::Error`;
/// - the output contains no JSON object;
/// - the JSON does not match the `QueryPlan` schema;
/// - the plan has no sub-queries.
async fn try_plan(
    query: &str,
    flash_model: &str,
    providers: &ProviderRegistry,
) -> Result<QueryPlan> {
    use crate::provider::StreamEvent;

    // Mirrors the prompt's "Max 5 sub_queries" rule; enforced here in case the
    // model ignores it.
    const MAX_SUB_QUERIES: usize = 5;

    let (provider_name, model_id) = providers.resolve_model(flash_model);
    let provider = providers
        .get(provider_name)
        .map_err(|e| anyhow!("flash provider '{provider_name}' unavailable: {e}"))?;

    let messages = vec![Message {
        role: Role::User,
        content: MessageContent::Text(format!("Query: {query}\n\nOutput the JSON plan now.")),
    }];
    let req = LlmRequest {
        model: model_id.to_owned(),
        messages,
        tools: vec![],
        system: Some(planner_system()),
        max_tokens: Some(400),
        // Deterministic output: the same query should yield the same plan.
        temperature: Some(0.0),
        frequency_penalty: None,
        thinking_budget: None,
        kv_cache_mode: 0,
        session_key: None,
    };

    let mut stream = provider.stream(req).await?;
    let mut buf = String::new();
    while let Some(ev) = stream.next().await {
        match ev? {
            StreamEvent::TextDelta(t) => buf.push_str(&t),
            // Neither reasoning text nor tool calls are part of the JSON plan.
            StreamEvent::ReasoningDelta(_) | StreamEvent::ToolCall { .. } => {}
            StreamEvent::Done { .. } => break,
            StreamEvent::Error(e) => return Err(anyhow!("planner stream error: {e}")),
        }
    }

    let json = extract_json_object(&buf).ok_or_else(|| {
        anyhow!("planner output has no JSON object; got: {}", truncate(&buf, 200))
    })?;
    tracing::debug!(raw = %truncate(json, 500), "query_planner: raw LLM output");
    let mut plan: QueryPlan = serde_json::from_str(json).map_err(|e| {
        anyhow!("planner JSON parse failed: {e}; raw: {}", truncate(json, 200))
    })?;
    if plan.sub_queries.is_empty() {
        return Err(anyhow!("planner returned empty sub_queries"));
    }
    plan.sub_queries.truncate(MAX_SUB_QUERIES);
    Ok(plan)
}
/// Return the first balanced `{ ... }` object in `s`, or `None` if there is none.
///
/// The scan starts at the first `{`. Braces are counted only outside JSON
/// string literals, and backslash escapes inside strings are honoured, so
/// `"}"` or `"\""` do not end the object early. Surrounding prose or markdown
/// code fences are ignored.
///
/// The scan works on bytes. It only reacts to ASCII bytes, which never occur
/// inside a multi-byte UTF-8 sequence, so the returned slice always starts
/// and ends on char boundaries.
fn extract_json_object(s: &str) -> Option<&str> {
    let start = s.find('{')?;
    let mut depth: i32 = 0;
    let mut inside_string = false;
    let mut escaped = false;

    for (offset, byte) in s[start..].bytes().enumerate() {
        if inside_string {
            if escaped {
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == b'"' {
                inside_string = false;
            }
            continue;
        }
        match byte {
            b'"' => inside_string = true,
            b'{' => depth += 1,
            b'}' => {
                depth -= 1;
                if depth == 0 {
                    let end = start + offset;
                    return Some(&s[start..=end]);
                }
            }
            _ => {}
        }
    }
    None
}
/// Shorten `s` to at most `n` bytes for log/error messages, appending `…` when cut.
///
/// If `n` falls inside a multi-byte UTF-8 character, the cut moves back to
/// the previous char boundary. The original sliced `&s[..n]` directly, which
/// panics on such input. Planner output and queries here are frequently
/// Chinese, so that case is common.
fn truncate(s: &str, n: usize) -> String {
    if s.len() <= n {
        return s.to_owned();
    }
    // Index 0 is always a char boundary, so this loop terminates.
    let mut end = n;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}…", &s[..end])
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_json_plain() {
        let out = extract_json_object(r#"{"sub_queries":[]}"#).unwrap();
        assert!(out.contains("sub_queries"));
    }

    #[test]
    fn extract_json_with_fence() {
        let s = "```json\n{\"sub_queries\":[]}\n```";
        let out = extract_json_object(s).unwrap();
        assert_eq!(out, r#"{"sub_queries":[]}"#);
    }

    #[test]
    fn extract_json_with_prose() {
        let s = "Sure, here's the plan: {\"sub_queries\":[{\"q\":\"x\",\"intent\":{\"kind\":\"general\"}}]}";
        let out = extract_json_object(s).unwrap();
        assert!(out.starts_with("{"));
        assert!(out.ends_with("}"));
    }

    #[test]
    fn extract_json_handles_nested() {
        let s = r#"{"a":{"b":1},"c":"}"}"#;
        let out = extract_json_object(s).unwrap();
        assert_eq!(out, s);
    }

    // A backslash-escaped quote must not terminate the string, so the `}`
    // after it is still treated as string content.
    #[test]
    fn extract_json_ignores_escaped_quote_in_string() {
        let s = r#"note: {"q":"say \"}\" now","intent":{"kind":"general"}} trailing"#;
        let out = extract_json_object(s).unwrap();
        assert_eq!(out, r#"{"q":"say \"}\" now","intent":{"kind":"general"}}"#);
    }

    // Truncated model output (e.g. hitting max_tokens) or no JSON at all.
    #[test]
    fn extract_json_unbalanced_or_missing_is_none() {
        assert!(extract_json_object(r#"{"sub_queries":[{"q":"x""#).is_none());
        assert!(extract_json_object("no json here").is_none());
        assert!(extract_json_object("").is_none());
    }

    // Multi-byte prose around and inside the object must slice cleanly.
    #[test]
    fn extract_json_with_multibyte_prose() {
        let s = "好的，计划如下：{\"sub_queries\":[{\"q\":\"天气\",\"intent\":{\"kind\":\"general\"}}]} 完毕";
        let out = extract_json_object(s).unwrap();
        assert_eq!(
            out,
            "{\"sub_queries\":[{\"q\":\"天气\",\"intent\":{\"kind\":\"general\"}}]}"
        );
        let p: QueryPlan = serde_json::from_str(out).unwrap();
        assert_eq!(p.sub_queries[0].q, "天气");
    }

    #[test]
    fn truncate_ascii() {
        assert_eq!(truncate("abc", 10), "abc");
        assert_eq!(truncate("abcdef", 3), "abc…");
    }

    #[test]
    fn passthrough_preserves_query() {
        let p = QueryPlan::passthrough("hello world");
        assert_eq!(p.sub_queries.len(), 1);
        assert_eq!(p.sub_queries[0].q, "hello world");
        assert!(matches!(p.sub_queries[0].intent, Intent::General));
    }

    #[test]
    fn parses_weather_multi() {
        let json = r#"{"sub_queries":[
{"q":"Bangkok weather","intent":{"kind":"weather","location":"Bangkok"}},
{"q":"Guangzhou weather","intent":{"kind":"weather","location":"Guangzhou"}}
]}"#;
        let p: QueryPlan = serde_json::from_str(json).unwrap();
        assert_eq!(p.sub_queries.len(), 2);
        match &p.sub_queries[0].intent {
            Intent::Weather { location } => assert_eq!(location, "Bangkok"),
            _ => panic!("expected weather"),
        }
    }

    #[test]
    fn parses_currency() {
        let json = r#"{"sub_queries":[
{"q":"USD to CNY","intent":{"kind":"currency","from":"USD","to":"CNY"}}
]}"#;
        let p: QueryPlan = serde_json::from_str(json).unwrap();
        assert!(matches!(&p.sub_queries[0].intent, Intent::Currency { from, to } if from == "USD" && to == "CNY"));
    }

    #[test]
    fn parses_flight_with_all_fields() {
        let json = r#"{"sub_queries":[
{"q":"北京飞曼谷机票","intent":{"kind":"flight","from":"北京","to":"曼谷","date":"","trip":"oneway"}}
]}"#;
        let p: QueryPlan = serde_json::from_str(json).unwrap();
        assert!(matches!(
            &p.sub_queries[0].intent,
            Intent::Flight { from, to, date, trip }
                if from == "北京" && to == "曼谷" && date.is_empty() && trip == "oneway"
        ));
    }
}