use anyhow::{Context, Result};
use rig::pipeline::{self, TryOp, agent_ops};
use rig::try_parallel;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::support::assert_nonempty_response;
use super::support;
// Output schema for the names extractor: a flat list of names pulled from the input text.
// Plain `//` comments on purpose — `///` doc comments would be picked up by `schemars`
// as schema descriptions and change the JSON schema the model is given.
#[derive(Serialize, Deserialize, JsonSchema, Debug)]
struct Names {
    names: Vec<String>,
}
// Output schema for the topics extractor: a flat list of topics found in the input text.
// Plain `//` comments on purpose — `///` doc comments would be picked up by `schemars`
// as schema descriptions and change the JSON schema the model is given.
#[derive(Serialize, Deserialize, JsonSchema, Debug)]
struct Topics {
    topics: Vec<String>,
}
// Output schema for the sentiment extractor: a sentiment score plus the model's confidence.
// NOTE(review): no numeric scale is specified for either field (here or in the preamble),
// so the model picks its own range — don't assert on bounds without pinning one down.
// Plain `//` comments on purpose — `///` doc comments would be picked up by `schemars`
// as schema descriptions and change the JSON schema the model is given.
#[derive(Serialize, Deserialize, JsonSchema, Debug)]
struct Sentiment {
    sentiment: f64,
    confidence: f64,
}
/// Batch-runs a pipeline that applies three structured extractors (`Names`, `Topics`,
/// `Sentiment`) to each input and renders their combined output as one string per input.
///
/// Ignored by default: it needs a local llama.cpp server exposing an OpenAI-compatible API,
/// reached through `support::completions_client` and `support::model_name`.
#[tokio::test]
#[ignore = "requires a local llama.cpp OpenAI-compatible server"]
async fn batch_multi_extract_chain() -> Result<()> {
    let client = support::completions_client();
    let model = support::model_name();

    // One extractor per output schema, all on the same model and each built with `.retries(2)`.
    let names_extractor = client
        .extractor::<Names>(model.clone())
        .preamble("Extract names from the given text.")
        .retries(2)
        .build();
    let topics_extractor = client
        .extractor::<Topics>(model.clone())
        .preamble("Extract topics from the given text.")
        .retries(2)
        .build();
    let sentiment_extractor = client
        .extractor::<Sentiment>(model)
        .preamble("Extract sentiment and confidence from the given text.")
        .retries(2)
        .build();

    // `try_parallel!` hands the same input to all three extract ops and yields their outputs
    // as a tuple; `map_ok` then renders that tuple into a single report string.
    let chain = pipeline::new()
        .chain(try_parallel!(
            agent_ops::extract(names_extractor),
            agent_ops::extract(topics_extractor),
            agent_ops::extract(sentiment_extractor),
        ))
        .map_ok(|(names, topics, sentiment)| {
            format!(
                "Extracted names: {}\nExtracted topics: {}\nExtracted sentiment: {} ({})",
                names.names.join(", "),
                topics.topics.join(", "),
                sentiment.sentiment,
                sentiment.confidence,
            )
        });

    let inputs = vec![
        "Screw you Putin!",
        "I love my dog, but I hate my cat.",
        "I'm going to the store to buy some milk.",
    ];
    // Derived from `inputs` so the length check cannot drift from the input list.
    let expected = inputs.len();

    // `4` is the `n` argument of `try_batch_call` (its concurrency bound per the rig docs).
    let responses = chain
        .try_batch_call(4, inputs)
        .await
        .with_context(|| format!("batch extraction over {expected} inputs failed"))?;

    assert_eq!(responses.len(), expected, "expected one response per input");
    // NOTE(review): every formatted response contains the literal "Extracted ..." labels, so if
    // `assert_nonempty_response` only checks for non-empty text it cannot fail here; the
    // substantive checks are the `?` above and the length assertion.
    for response in &responses {
        assert_nonempty_response(response);
    }

    Ok(())
}