use rig::{completion::ToolDefinition, tool::Tool};
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::{
error::SenseiError,
lua,
rcon_ext::{execute_lua_json, SharedRcon},
};
/// Tool that reports power-grid statistics by running a Lua query over RCON.
pub struct GetPowerStats {
    /// Shared RCON handle the Lua query is executed through.
    pub(crate) rcon: SharedRcon,
}
impl GetPowerStats {
    /// Creates the tool around an existing shared RCON handle.
    ///
    /// The handle is stored as-is and used by every subsequent `call`.
    #[must_use]
    pub const fn new(rcon: SharedRcon) -> Self {
        Self { rcon }
    }
}
/// Arguments for [`GetPowerStats`].
///
/// The tool advertises an object schema with no properties, so this is an empty
/// struct deserialized from a JSON object.
#[derive(Deserialize, Debug)]
pub struct GetPowerStatsArgs {}
/// Snapshot of power-grid statistics returned by [`GetPowerStats`].
///
/// Parsed from the JSON emitted by the `lua::power_stats` script. It is a plain
/// value type of three `f64`s, so it is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct PowerStats {
    /// Total power production. The field name says watts; NOTE(review): confirm
    /// the unit against the Lua script that fills it.
    pub production_watts: f64,
    /// Total power consumption. The field name says watts; see note above on units.
    pub consumption_watts: f64,
    /// Satisfaction ratio. The unit tests use 1.0 when fully supplied or idle, and
    /// 0.5 when consumption is double production. NOTE(review): confirm the exact
    /// formula in the Lua script.
    pub satisfaction: f64,
}
impl Tool for GetPowerStats {
    const NAME: &'static str = "get_power_stats";

    type Error = SenseiError;
    type Args = GetPowerStatsArgs;
    type Output = PowerStats;

    /// Describes the tool to the model; it takes no parameters.
    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            // Derive the advertised name from `NAME` so the two can never drift apart.
            name: Self::NAME.to_string(),
            description: "Get the power grid statistics: total production, consumption, and satisfaction ratio".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {}
            }),
        }
    }

    /// Runs the `power_stats` Lua script over RCON and parses its JSON output.
    ///
    /// # Errors
    ///
    /// Returns a [`SenseiError`] if executing the script over RCON fails, or if the
    /// returned JSON cannot be deserialized into [`PowerStats`].
    async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
        // Named `script` rather than `lua` to avoid shadowing the `lua` module.
        let script = lua::power_stats();
        let json = execute_lua_json(&self.rcon, &script).await?;
        Ok(serde_json::from_str(&json)?)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Normal operation: integer JSON numbers parse into `f64` fields.
    #[test]
    fn test_parse_power_stats() {
        let json = r#"{"production_watts":5000000,"consumption_watts":3500000,"satisfaction":1.0}"#;
        let stats: PowerStats = serde_json::from_str(json).unwrap();
        assert_eq!(
            stats,
            PowerStats {
                production_watts: 5_000_000.0,
                consumption_watts: 3_500_000.0,
                satisfaction: 1.0,
            }
        );
    }

    /// No grid at all: every field is checked, including consumption.
    #[test]
    fn test_parse_no_power() {
        let json = r#"{"production_watts":0,"consumption_watts":0,"satisfaction":1.0}"#;
        let stats: PowerStats = serde_json::from_str(json).unwrap();
        assert_eq!(
            stats,
            PowerStats {
                production_watts: 0.0,
                consumption_watts: 0.0,
                satisfaction: 1.0,
            }
        );
    }

    /// Brownout: demand exceeds supply and satisfaction drops below 1.0.
    #[test]
    fn test_parse_brownout() {
        let json = r#"{"production_watts":1000,"consumption_watts":2000,"satisfaction":0.5}"#;
        let stats: PowerStats = serde_json::from_str(json).unwrap();
        assert_eq!(
            stats,
            PowerStats {
                production_watts: 1000.0,
                consumption_watts: 2000.0,
                satisfaction: 0.5,
            }
        );
    }

    /// A payload missing a required field must be rejected, not defaulted.
    #[test]
    fn test_missing_field_is_rejected() {
        let json = r#"{"production_watts":1000,"consumption_watts":2000}"#;
        assert!(serde_json::from_str::<PowerStats>(json).is_err());
    }

    /// The serialized tool output parses back to the same value.
    #[test]
    fn test_serialize_round_trip() {
        let stats = PowerStats {
            production_watts: 1234.5,
            consumption_watts: 678.25,
            satisfaction: 1.0,
        };
        let json = serde_json::to_string(&stats).unwrap();
        let back: PowerStats = serde_json::from_str(&json).unwrap();
        assert_eq!(back, stats);
    }

    /// Pins the public tool name the model sees.
    #[test]
    fn test_tool_name() {
        assert_eq!(GetPowerStats::NAME, "get_power_stats");
    }
}