use serde_json::json;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokitai::{compose, tool, ToolProvider};
/// One flight offer, as produced by the search step of the trip fixtures.
///
/// The same type flows through the price filter and into the booking step.
#[derive(Debug, Clone, PartialEq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Flight {
    /// Airline code, e.g. `"AC"`.
    pub airline: String,
    /// Ticket price compared against `max_price` by the filter step.
    pub price: f64,
    /// Flight number, e.g. `"AC101"`; reused in the confirmation code.
    pub flight_no: String,
}
/// Result of the booking step: which flight was booked and under what code.
#[derive(Debug, Clone, PartialEq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct BookingConfirmation {
    /// Flight number of the booked flight.
    pub flight_no: String,
    /// Confirmation code handed on to the e-mail step.
    pub confirmation_code: String,
}
/// Running count of `TripPlanner` step invocations.
///
/// Every step method on `TripPlanner` adds one to this counter, which lets the
/// end-to-end test verify how many steps the composed tool actually ran.
static STEP_CALL_ORDER: AtomicUsize = AtomicUsize::new(0);

/// Sets [`STEP_CALL_ORDER`] back to zero before a test starts counting.
fn reset_step_counter() {
    // The previous value is discarded; the net effect is a plain store of 0.
    let _previous = STEP_CALL_ORDER.swap(0, Ordering::SeqCst);
}
/// Test fixture whose four step methods are composed into one `book_trip` tool.
pub struct TripPlanner {
    /// Recipient named in the message produced by `send_email`.
    pub email_recipient: String,
}

// Methods inside this `#[tool]` impl are commented with plain `//` rather than
// `///`. The macros may turn doc attributes into tool descriptions, which would
// change the byte counts compared in `test_compose_token_savings_assertion`.
// NOTE(review): confirm against tokitai.
//
// The step types line up so that each step's return value can feed the next:
// `Vec<Flight>` -> `Vec<Flight>` -> `BookingConfirmation` -> `String`.
// `max_price` is the extra argument the schema test expects callers to supply.
#[compose(
    name = "book_trip",
    steps = [search_flights, filter_by_price, book_flight, send_email]
)]
#[tool]
impl TripPlanner {
    // Step 1: returns a fixed list of three flights regardless of the route.
    //
    // If the macro takes schema keys verbatim from parameter names, an
    // underscore prefix (`_origin`/`_dest`) would leak into the input schema.
    // Naming them `origin`/`dest` matches the keys callers send and the ones
    // `UncomposedTripPlanner::search_flights` exposes. The `let _` binding
    // silences the unused-variable lint instead.
    pub fn search_flights(&self, origin: String, dest: String) -> Vec<Flight> {
        let _ = (origin, dest);
        STEP_CALL_ORDER.fetch_add(1, Ordering::SeqCst);
        vec![
            Flight {
                airline: "AC".to_string(),
                price: 250.0,
                flight_no: "AC101".to_string(),
            },
            Flight {
                airline: "UA".to_string(),
                price: 420.0,
                flight_no: "UA202".to_string(),
            },
            Flight {
                airline: "DL".to_string(),
                price: 180.0,
                flight_no: "DL303".to_string(),
            },
        ]
    }

    // Step 2: keeps only flights priced at or below `max_price` (inclusive).
    pub fn filter_by_price(&self, flights: Vec<Flight>, max_price: f64) -> Vec<Flight> {
        STEP_CALL_ORDER.fetch_add(1, Ordering::SeqCst);
        flights
            .into_iter()
            .filter(|f| f.price <= max_price)
            .collect()
    }

    // Step 3: books the cheapest remaining flight.
    //
    // `f64::total_cmp` is a total order, so a NaN price can no longer make the
    // comparator panic the way `partial_cmp(..).unwrap()` did. Ties still
    // resolve to the first flight in the list (`min_by` semantics).
    //
    // Panics with "at least one flight" when `flights` is empty, e.g. when
    // nothing is under `max_price`. The step must return a
    // `BookingConfirmation`, so there is no value to fall back to.
    pub fn book_flight(&self, flights: Vec<Flight>) -> BookingConfirmation {
        STEP_CALL_ORDER.fetch_add(1, Ordering::SeqCst);
        let cheapest = flights
            .into_iter()
            .min_by(|a, b| a.price.total_cmp(&b.price))
            .expect("at least one flight");
        // Build the code first so the flight number can be moved, not cloned.
        let confirmation_code = format!("CONF-{}", cheapest.flight_no);
        BookingConfirmation {
            flight_no: cheapest.flight_no,
            confirmation_code,
        }
    }

    // Step 4: formats a confirmation message for `email_recipient`. Nothing is
    // actually sent; the message string is the composed tool's final output.
    pub fn send_email(&self, confirmation: BookingConfirmation) -> String {
        STEP_CALL_ORDER.fetch_add(1, Ordering::SeqCst);
        format!(
            "Email sent to {}: booking {} confirmed ({})",
            self.email_recipient, confirmation.flight_no, confirmation.confirmation_code
        )
    }
}
/// Composing four steps must register exactly one `book_trip` tool entry.
#[test]
fn test_compose_4step_schema_has_one_entry() {
    let definitions = TripPlanner::tool_definitions();
    let matches = definitions
        .iter()
        .filter(|def| def.name == "book_trip")
        .count();
    assert_eq!(
        matches, 1,
        "expected exactly one `book_trip` tool entry; got {}",
        matches
    );
}
/// The composed tool's input schema must list the first step's arguments
/// (`origin`, `dest`) plus the pass-through `max_price`, and nothing else.
#[test]
fn test_compose_4step_input_schema_is_first_step() {
    use std::collections::BTreeSet;

    let definitions = TripPlanner::tool_definitions();
    let book_trip = definitions
        .iter()
        .find(|def| def.name == "book_trip")
        .expect("book_trip tool not registered");

    let schema_text = &book_trip.input_schema;
    let schema: serde_json::Value =
        serde_json::from_str(schema_text).expect("schema is valid JSON");
    let properties = schema
        .get("properties")
        .and_then(serde_json::Value::as_object)
        .expect("schema must carry `properties`");

    // Compare as sets so the key order emitted by the macro does not matter.
    let actual: BTreeSet<&str> = properties.keys().map(String::as_str).collect();
    let expected = BTreeSet::from(["origin", "dest", "max_price"]);
    assert_eq!(
        actual, expected,
        "composed tool's input schema must be the first step's args + pass-through; got {:?}",
        actual
    );
}
/// Invokes `book_trip` end to end. The final message must name the cheapest
/// flight under the price cap and the recipient, and each of the four steps
/// must run exactly once.
#[test]
fn test_compose_4step_end_to_end_invocation() {
    reset_step_counter();

    let planner = TripPlanner {
        email_recipient: "alice@example.com".to_string(),
    };
    // A 300.0 cap drops UA202 (420.0); DL303 (180.0) is then the cheapest.
    let arguments = json!({
        "origin": "SFO",
        "dest": "JFK",
        "max_price": 300.0,
    });

    let outcome = planner
        .call_tool("book_trip", &arguments)
        .expect("book_trip call should succeed");
    let message = outcome
        .as_str()
        .expect("composed tool result should be a string (last step's return type)");

    assert!(
        message.contains("DL303"),
        "expected the cheapest filtered flight (DL303) in the email; got: {}",
        message
    );
    assert!(
        message.contains("alice@example.com"),
        "expected the email recipient in the body; got: {}",
        message
    );

    let step_calls = STEP_CALL_ORDER.load(Ordering::SeqCst);
    assert_eq!(
        step_calls, 4,
        "expected 4 step invocations (one per step); got {}",
        step_calls
    );
}
/// Intermediate value handed from `fetch_weather` to `summarize` in the
/// `weather_summary` chain.
#[derive(Debug, Clone, PartialEq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct Weather {
    /// City the reading belongs to.
    pub city: String,
    /// Temperature, presumably in degrees Celsius (per the `_c` suffix).
    pub temperature_c: f64,
    /// Humidity; `summarize` prints it followed by a `%` sign.
    pub humidity: f64,
}
/// Stateless two-step fixture composed into a single `weather_summary` tool.
pub struct WeatherService;

// Plain `//` comments inside this macro-annotated impl keep doc text out of
// anything the `tool`/`compose` macros may generate from doc attributes.
#[compose(name = "weather_summary", steps = [fetch_weather, summarize])]
#[tool]
impl WeatherService {
    // Step 1: returns canned readings (22.5 temperature, 65.0 humidity) for `city`.
    pub fn fetch_weather(&self, city: String) -> Weather {
        Weather {
            city,
            temperature_c: 22.5,
            humidity: 65.0,
        }
    }

    // Step 2: renders a one-line summary with one decimal place for the
    // temperature and none for the humidity. The degree sign is the real
    // U+00B0 character; this literal previously held the mis-encoded bytes "ยฐ".
    pub fn summarize(&self, w: Weather) -> String {
        format!(
            "Weather in {}: {:.1}°C, humidity {:.0}%",
            w.city, w.temperature_c, w.humidity
        )
    }
}
/// A two-step chain must carry a custom struct (`Weather`) between steps. The
/// rendered summary must mention both the city and the temperature.
#[test]
fn test_compose_2step_with_custom_struct_intermediate() {
    let svc = WeatherService;
    let output = svc
        .call_tool("weather_summary", &json!({ "city": "Paris" }))
        .expect("weather_summary call should succeed");
    let summary = output
        .as_str()
        .expect("composed tool result should be a string");

    let names_city = summary.contains("Paris");
    let names_temperature = summary.contains("22.5");
    assert!(
        names_city && names_temperature,
        "expected the summary to mention the city and the temperature; got: {}",
        summary
    );
}
/// Baseline fixture: the same four step signatures as `TripPlanner`, exposed
/// as four separate tools with no `compose`.
///
/// In this file it is only used through `tool_definitions()` for the
/// schema-size comparison.
pub struct UncomposedTripPlanner {
    /// Not read by any method below; kept for parity with `TripPlanner`.
    pub email_recipient: String,
}

// Only the signatures matter for the generated tool definitions, so every body
// discards its inputs and returns an empty value. Parameter names are left as-is
// because they shape the input schemas. Plain `//` comments keep doc text out
// of anything the macro may generate.
#[tool]
impl UncomposedTripPlanner {
    // Ignores the route and yields no flights.
    pub fn search_flights(&self, origin: String, dest: String) -> Vec<Flight> {
        drop((origin, dest));
        Vec::new()
    }

    // Ignores both inputs and yields no flights.
    pub fn filter_by_price(&self, flights: Vec<Flight>, max_price: f64) -> Vec<Flight> {
        drop((flights, max_price));
        Vec::new()
    }

    // Ignores the flights and yields a confirmation with empty strings.
    pub fn book_flight(&self, flights: Vec<Flight>) -> BookingConfirmation {
        drop(flights);
        BookingConfirmation {
            flight_no: String::new(),
            confirmation_code: String::new(),
        }
    }

    // Ignores the confirmation and yields an empty message.
    pub fn send_email(&self, confirmation: BookingConfirmation) -> String {
        drop(confirmation);
        String::new()
    }
}
/// Composing four steps into `book_trip` must shrink the advertised tool
/// metadata (name + description + input schema) by at least 1.2x.
///
/// The baseline is the same four steps exposed as separate tools. Byte length
/// is used as a proxy for the prompt-token cost of advertising the tools.
#[test]
fn test_compose_token_savings_assertion() {
    let composed_tools = TripPlanner::tool_definitions();
    let book_trip_entries = composed_tools
        .iter()
        .filter(|t| t.name == "book_trip")
        .count();
    assert_eq!(
        book_trip_entries, 1,
        "composed impl must have exactly one `book_trip` entry; got {}",
        book_trip_entries
    );

    let uncomposed_tools = UncomposedTripPlanner::tool_definitions();
    assert_eq!(
        uncomposed_tools.len(),
        4,
        "un-composed impl must expose 4 entries (one per step)"
    );

    let book_trip = composed_tools
        .iter()
        .find(|t| t.name == "book_trip")
        .expect("book_trip tool not registered");
    let composed_bytes: usize =
        book_trip.name.len() + book_trip.description.len() + book_trip.input_schema.len();

    let mut uncomposed_bytes: usize = 0;
    for tool in uncomposed_tools.iter() {
        uncomposed_bytes += tool.name.len() + tool.description.len() + tool.input_schema.len();
    }

    let delta = uncomposed_bytes.saturating_sub(composed_bytes);
    // Guard the division; a zero-byte composed entry reports a ratio of 0.0.
    let ratio = match composed_bytes {
        0 => 0.0,
        n => uncomposed_bytes as f64 / n as f64,
    };

    eprintln!(
        "[T-017] token-savings: composed_entry={} bytes, un-composed_union={} bytes, \
         delta={} bytes, ratio={:.2}x",
        composed_bytes, uncomposed_bytes, delta, ratio
    );

    assert!(
        composed_bytes < uncomposed_bytes,
        "composed entry ({} bytes) must be smaller than the union of 4 un-composed entries \
         ({} bytes); got ratio {:.2}x",
        composed_bytes,
        uncomposed_bytes,
        ratio
    );
    assert!(
        ratio >= 1.2,
        "expected at least 1.2x byte reduction from compose on a 4-step chain; got {:.2}x",
        ratio
    );
}
/// Plain `#[tool]` fixture with no composition, used as a regression check
/// that non-composed tools keep working.
pub struct Calculator;

// Plain `//` comments keep doc text out of anything the macro may generate.
#[tool]
impl Calculator {
    // Returns the sum of the two integers. Overflow follows the `+` operator:
    // a panic when overflow checks are on (default debug profile), wrapping otherwise.
    pub fn add(&self, a: i32, b: i32) -> i32 {
        b + a
    }
}
/// A plain `#[tool]` impl must still expose exactly its one method and
/// dispatch calls to it.
#[test]
fn test_backwards_compat_plain_tool_unchanged() {
    let definitions = Calculator::tool_definitions();
    assert_eq!(definitions.len(), 1);
    assert_eq!(definitions[0].name, "add");

    let calculator = Calculator;
    let sum = calculator
        .call_tool("add", &json!({"a": 2, "b": 3}))
        .expect("add call should succeed");
    assert_eq!(sum, json!(5));
}
/// The fixture `tests/ui/compose_chain_mismatch.rs` (presumably a `compose`
/// chain whose step types do not line up) must fail to compile.
#[test]
fn test_compose_negative_compile_fail() {
    let cases = trybuild::TestCases::new();
    cases.compile_fail("tests/ui/compose_chain_mismatch.rs");
    // trybuild runs the registered cases when `cases` is dropped here.
}