use anyhow::{Result, anyhow};
use serde::Serialize;
use pidge_client::{AuthClient, GraphClient};
use pidge_core::Config;
use crate::cli::ClassifyArgs;
use crate::commands::classify_cache::{ClassifyCache, cache_key};
use crate::commands::classify_model::{AilloyModel, LabelModel, build_input};
use crate::commands::classify_parse::{parse_labels, validate_labels};
use crate::commands::mail_fragment::resolve;
use futures::stream::{self, StreamExt};
/// One classification result as emitted in JSON output.
///
/// Used both for single-input results (`emit_single`) and for each element of
/// the array printed by `run_batch`.
#[derive(Serialize)]
struct ClassifyOut {
/// Short identifier of the classified message; `None` when the input was
/// literal text rather than a mailbox message.
hash: Option<String>,
/// Sender address of the classified message; `None` for literal text input.
from: Option<String>,
/// Labels that remained after the model's answer was parsed and validated.
classification: Vec<String>,
}
pub fn resolve_prompt(
prompt: Option<String>,
prompt_file: Option<String>,
config_prompt: Option<String>,
) -> Result<String> {
if let Some(p) = prompt {
return Ok(p);
}
if let Some(f) = prompt_file {
let s = if f == "-" {
use std::io::Read;
let mut s = String::new();
std::io::stdin().read_to_string(&mut s)?;
s
} else {
std::fs::read_to_string(&f)?
};
return Ok(s);
}
config_prompt.ok_or_else(|| {
anyhow!("no prompt: pass --prompt/--prompt-file or set one with `pidge config set classify.prompt`")
})
}
/// Entry point for the classify command.
///
/// Fails immediately when AI is not active. Otherwise it resolves the prompt
/// and the allowed label set (labels given in `args` win over the configured
/// ones) and dispatches on the input source:
///
/// * `args.text` — classify the literal text and print the labels;
/// * `args.fragment` — resolve the fragment to a single message, classify it
///   and, when `args.set_category` is set and the result is worth it, write
///   the labels back as categories;
/// * otherwise — delegate to [`run_batch`].
///
/// `json` selects JSON output instead of plain text.
pub async fn run(args: ClassifyArgs, json: bool) -> Result<()> {
    if !crate::commands::ai::is_ai_active() {
        return Err(anyhow!(
            "AI is not active for pidge. Run `pidge ai enable` and `pidge ai config` to set up a provider."
        ));
    }

    let config = Config::load()?;
    let prompt = resolve_prompt(
        args.prompt.clone(),
        args.prompt_file.clone(),
        config.classify.prompt.clone(),
    )?;

    // Labels passed on the command line replace the configured set entirely.
    let label_source = if args.labels.is_empty() {
        &config.classify.labels
    } else {
        &args.labels
    };
    let allowed = label_source.clone();

    // Literal text: no mailbox access needed.
    if let Some(text) = &args.text {
        let model = AilloyModel::new()?;
        let raw = model.classify(&prompt, text).await?;
        let labels = validate_labels(parse_labels(&raw), &allowed);
        return emit_single(None, None, labels, json);
    }

    // A single message addressed by fragment.
    if let Some(fragment) = &args.fragment {
        let (short, msg) = resolve(fragment)?;
        let graph = GraphClient::new(AuthClient::from_env()?)?;
        let full = graph.get_message(&msg.account, &msg.graph_id).await?;
        let body = body_text(&full);
        let input = build_input(&full.subject, &full.from.address, &body);

        let model = AilloyModel::new()?;
        let raw = model.classify(&prompt, &input).await?;
        let labels = validate_labels(parse_labels(&raw), &allowed);

        if args.set_category && worth_categorizing(&labels) {
            graph
                .set_categories(&msg.account, &msg.graph_id, &labels)
                .await?;
        }
        return emit_single(Some(short), Some(full.from.address), labels, json);
    }

    run_batch(args, &prompt, &allowed, json).await
}
/// Classifies every message selected by `args`, running up to `parallel`
/// classifications concurrently, and prints one result per message.
///
/// * Concurrency comes from `args.parallel`, then the configuration, then a
///   default of 4, and is never less than 1.
/// * Unless disabled by `args.no_cache` or the configuration, results are
///   looked up in and written back to the classify cache, keyed by message id
///   and prompt.
/// * A message whose classification fails is reported on stderr, emitted with
///   the single label `unknown`, and not cached.
/// * With `args.set_category`, labels worth categorizing are written back to
///   the message; failures there are reported on stderr but do not abort.
///
/// Results are collected with `buffer_unordered`, so output order follows
/// completion order rather than selection order.
// The tuple types below are plumbing local to this function; naming them
// would add more noise than it removes.
#[allow(clippy::type_complexity)]
pub(crate) async fn run_batch(
    args: ClassifyArgs,
    prompt: &str,
    allowed: &[String],
    json: bool,
) -> Result<()> {
    let config = Config::load()?;
    let parallel = args
        .parallel
        .or(config.classify.parallel)
        .unwrap_or(4)
        .max(1);
    let use_cache = !args.no_cache && config.classify.cache.unwrap_or(true);

    let g = GraphClient::new(AuthClient::from_env()?)?;
    let messages = select_messages(&g, &args).await?;
    let model = AilloyModel::new()?;
    let mut cache = if use_cache {
        ClassifyCache::load()
    } else {
        ClassifyCache::default()
    };

    // Phase 1: resolve cache hits up front so the concurrent phase never has
    // to touch the cache.
    let mut prepared: Vec<(String, String, String, String, String, Option<Vec<String>>)> =
        Vec::with_capacity(messages.len());
    for (account, id, short, from) in messages {
        let key = cache_key(&id, prompt);
        let cached = if use_cache { cache.get(&key) } else { None };
        prepared.push((account, id, short, from, key, cached));
    }

    // Phase 2: classify the misses (and optionally set categories) concurrently.
    let set_category = args.set_category;
    let jobs = prepared
        .into_iter()
        .map(|(account, id, short, from, key, cached)| {
            let graph = &g;
            let labeler = &model;
            async move {
                let (labels, fresh) = if let Some(hit) = cached {
                    (hit, false)
                } else {
                    match classify_one(graph, labeler, &account, &id, prompt, allowed).await {
                        Ok(found) => (found, true),
                        Err(e) => {
                            eprintln!("  ! {short}: {e}");
                            // Not marked fresh, so the failure is never cached.
                            (vec!["unknown".to_string()], false)
                        }
                    }
                };

                let wants_category = set_category && worth_categorizing(&labels);
                if wants_category {
                    if let Err(e) = graph.set_categories(&account, &id, &labels).await {
                        eprintln!("  ! {short}: set-category failed: {e}");
                    }
                }

                (short, from, key, labels, fresh)
            }
        });
    let results = stream::iter(jobs)
        .buffer_unordered(parallel)
        .collect::<Vec<(String, String, String, Vec<String>, bool)>>()
        .await;

    // Phase 3: record fresh results in the cache and build the output rows.
    let mut out = Vec::with_capacity(results.len());
    for (short, from, key, labels, fresh) in results {
        if use_cache && fresh {
            cache.put(key, labels.clone());
        }
        out.push(ClassifyOut {
            hash: Some(short),
            from: Some(from),
            classification: labels,
        });
    }
    if use_cache {
        cache.save();
    }

    if json {
        println!("{}", serde_json::to_string_pretty(&out)?);
        return Ok(());
    }
    for ClassifyOut {
        hash,
        from,
        classification,
    } in &out
    {
        println!(
            "{}  {}  {}",
            hash.as_deref().unwrap_or(""),
            from.as_deref().unwrap_or(""),
            classification.join(", ")
        );
    }
    Ok(())
}
/// Fetches one message and classifies it.
///
/// Loads message `id` from `account`, builds the model input from its
/// subject, sender address and body text, asks `model` to classify it with
/// `prompt`, and returns the parsed labels after validating them against
/// `allowed`.
///
/// # Errors
///
/// Propagates failures from fetching the message or from the model call.
async fn classify_one<M>(
    g: &GraphClient,
    model: &M,
    account: &str,
    id: &str,
    prompt: &str,
    allowed: &[String],
) -> Result<Vec<String>>
where
    M: LabelModel,
{
    let message = g.get_message(account, id).await?;
    let body = body_text(&message);
    let input = build_input(&message.subject, &message.from.address, &body);
    let answer = model.classify(prompt, &input).await?;
    let parsed = parse_labels(&answer);
    Ok(validate_labels(parsed, allowed))
}
/// Collects the messages to classify as
/// `(account email, message id, short hash, sender address)` tuples.
///
/// Accounts: every signed-in account, or only those named in `args.account`
/// (each must be known to the configuration).
///
/// Per account, the source of messages is, in order of precedence:
///
/// 1. `args.from` — a search per sender, keeping only messages whose sender
///    address matches ASCII-case-insensitively (`args.folder` is not consulted
///    in this mode);
/// 2. `args.folder` — the listing of that folder;
/// 3. otherwise the inbox.
///
/// Folder and inbox listings request `args.limit` messages (default 100) but
/// never fewer than 50. Messages received at or after the `args.older_than`
/// cutoff are dropped, and the combined list is finally truncated to
/// `args.limit` when one is given.
///
/// Fetching is best-effort: a failed search or listing is reported on stderr
/// and treated as empty so the remaining senders and accounts are still
/// processed.
///
/// # Errors
///
/// Fails when no account is signed in, a requested account is unknown, the
/// `older_than` value cannot be parsed, or folder resolution fails.
// The four-string tuple is consumed positionally by `run_batch` only; a named
// type would not add clarity for a single private call site.
#[allow(clippy::type_complexity)]
async fn select_messages(
    g: &GraphClient,
    args: &ClassifyArgs,
) -> Result<Vec<(String, String, String, String)>> {
    use pidge_core::short_hash;

    let config = Config::load()?;
    if config.accounts.is_empty() {
        return Err(anyhow!("No accounts signed in. Run `pidge account add`."));
    }

    let target_emails: Vec<String> = if args.account.is_empty() {
        config.accounts.iter().map(|a| a.email.clone()).collect()
    } else {
        for f in &args.account {
            if config.find(f).is_none() {
                return Err(anyhow!("not signed in to {f}"));
            }
        }
        args.account.clone()
    };

    let cutoff = args
        .older_than
        .as_deref()
        .map(crate::commands::mail_delete::parse_older_than)
        .transpose()?;
    let fetch = args.limit.unwrap_or(100).max(50);

    let mut out = Vec::new();
    for email in &target_emails {
        let msgs: Vec<pidge_core::Message> = if !args.from.is_empty() {
            let mut v = Vec::new();
            for sender in &args.from {
                let found = g
                    .search_messages(email, &format!("from:{sender}"), 1000)
                    .await
                    .unwrap_or_else(|e| {
                        // Keep going, but do not let a failure look like "no mail".
                        eprintln!("  ! {email}: search for mail from {sender} failed: {e}");
                        Default::default()
                    });
                // The search can match more loosely than an exact address, so
                // keep exact (ASCII-case-insensitive) sender matches only.
                v.extend(
                    found
                        .into_iter()
                        .filter(|m| m.from.address.eq_ignore_ascii_case(sender)),
                );
            }
            v
        } else if let Some(folder) = &args.folder {
            match crate::commands::mail_folders::resolve_folder_path(g, email, folder).await? {
                Some(id) => g
                    .list_folder(email, &id, fetch, 0, false)
                    .await
                    .map(|p| p.messages)
                    .unwrap_or_else(|e| {
                        eprintln!("  ! {email}: listing folder '{folder}' failed: {e}");
                        Vec::new()
                    }),
                None => {
                    eprintln!("  ! no folder '{folder}' in {email}");
                    Vec::new()
                }
            }
        } else {
            g.list_inbox(email, fetch, 0, false)
                .await
                .map(|p| p.messages)
                .unwrap_or_else(|e| {
                    eprintln!("  ! {email}: listing inbox failed: {e}");
                    Vec::new()
                })
        };

        for m in msgs {
            // With a cutoff, only messages strictly older than it qualify.
            if matches!(cutoff, Some(c) if m.received_at >= c) {
                continue;
            }
            out.push((
                email.clone(),
                m.id.clone(),
                short_hash(&m.id),
                m.from.address.clone(),
            ));
        }
    }

    if let Some(limit) = args.limit {
        out.truncate(limit);
    }
    Ok(out)
}
/// Returns the message body as plain text.
///
/// Text bodies are returned unchanged; HTML bodies are rendered to text with
/// `html2text` at a width of 100 columns.
pub(crate) fn body_text(m: &pidge_core::FullMessage) -> String {
    use pidge_core::BodyContentType;

    let raw = &m.body_content;
    match m.body_content_type {
        BodyContentType::Html => html2text::from_read(raw.as_bytes(), 100),
        BodyContentType::Text => raw.clone(),
    }
}
/// Reports whether `labels` should be written back to the message as
/// categories: everything except the lone label `unknown` qualifies.
fn worth_categorizing(labels: &[String]) -> bool {
    match labels {
        [only] => only.as_str() != "unknown",
        _ => true,
    }
}
/// Prints the result of classifying a single input.
///
/// With `json`, one compact JSON object carrying `hash`, `from` and the
/// labels is written; otherwise each label is printed on its own line and
/// `hash`/`from` are not shown.
fn emit_single(
    hash: Option<String>,
    from: Option<String>,
    labels: Vec<String>,
    json: bool,
) -> Result<()> {
    if !json {
        for label in &labels {
            println!("{label}");
        }
        return Ok(());
    }

    let record = ClassifyOut {
        hash,
        from,
        classification: labels,
    };
    println!("{}", serde_json::to_string(&record)?);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit prompt wins over the configured one.
    #[test]
    fn resolve_prompt_prefers_flag() {
        let got = resolve_prompt(Some("flag".into()), None, Some("cfg".into())).unwrap();
        assert_eq!(got, "flag");
    }

    /// Without a prompt or prompt file, the configured prompt is used.
    #[test]
    fn resolve_prompt_falls_back_to_config() {
        let got = resolve_prompt(None, None, Some("cfg".into())).unwrap();
        assert_eq!(got, "cfg");
    }

    /// With no source at all, resolution fails.
    #[test]
    fn resolve_prompt_errors_when_unset() {
        let outcome = resolve_prompt(None, None, None);
        assert!(outcome.is_err());
    }

    /// Only the single label `unknown` is not worth categorizing.
    #[test]
    fn worth_categorizing_rejects_lone_unknown() {
        let labels =
            |names: &[&str]| -> Vec<String> { names.iter().map(|n| n.to_string()).collect() };

        assert!(!worth_categorizing(&labels(&["unknown"])));
        assert!(worth_categorizing(&labels(&["receipt"])));
        assert!(worth_categorizing(&labels(&["unknown", "receipt"])));
    }
}