Skip to main content

custom_tools/
custom_tools.rs

1use antigravity_sdk_rust::agent::{Agent, AgentConfig};
2use antigravity_sdk_rust::policy;
3use antigravity_sdk_rust::tools::Tool;
4use antigravity_sdk_rust::types::GeminiConfig;
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use tracing_subscriber::EnvFilter;
9
10struct LookupSkuTool;
11
12impl Tool for LookupSkuTool {
13    fn name(&self) -> &'static str {
14        "lookup_fruit_sku"
15    }
16
17    fn description(&self) -> &'static str {
18        "Looks up the SKU for a given fruit."
19    }
20
21    fn parameters_json_schema(&self) -> &'static str {
22        r#"{
23            "type": "object",
24            "properties": {
25                "fruit_name": {
26                    "type": "string"
27                }
28            },
29            "required": ["fruit_name"]
30        }"#
31    }
32
33    async fn call(&self, args: Value) -> Result<Value, anyhow::Error> {
34        let fruit_name = args
35            .get("fruit_name")
36            .and_then(Value::as_str)
37            .ok_or_else(|| anyhow::anyhow!("Missing fruit_name"))?
38            .to_lowercase();
39
40        let sku = match fruit_name.trim_end_matches('s') {
41            "apple" => "SKU-APP-123",
42            "banana" => "SKU-BAN-456",
43            "orange" => "SKU-ORA-789",
44            _ => "SKU-GEN-000",
45        };
46
47        Ok(Value::String(format!(
48            "SKU for {} is {}. Order ID for restocking: ORD-{}-NEW",
49            fruit_name, sku, sku
50        )))
51    }
52}
53
54struct RecordFruitTool {
55    inventory: Arc<Mutex<HashMap<String, u32>>>,
56}
57
58impl Tool for RecordFruitTool {
59    fn name(&self) -> &'static str {
60        "record_fruit"
61    }
62
63    fn description(&self) -> &'static str {
64        "Records the count of fruits by SKU."
65    }
66
67    fn parameters_json_schema(&self) -> &'static str {
68        r#"{
69            "type": "object",
70            "properties": {
71                "sku": {
72                    "type": "string"
73                },
74                "count": {
75                    "type": "integer"
76                }
77            },
78            "required": ["sku", "count"]
79        }"#
80    }
81
82    async fn call(&self, args: Value) -> Result<Value, anyhow::Error> {
83        let sku = args
84            .get("sku")
85            .and_then(Value::as_str)
86            .ok_or_else(|| anyhow::anyhow!("Missing sku"))?
87            .to_string();
88
89        let count = args
90            .get("count")
91            .and_then(Value::as_u64)
92            .ok_or_else(|| anyhow::anyhow!("Missing count"))? as u32;
93
94        let mut inv = self
95            .inventory
96            .lock()
97            .map_err(|e| anyhow::anyhow!("Mutex lock poisoned: {}", e))?;
98        let entry = inv.entry(sku.clone()).or_insert(0);
99        *entry += count;
100        let total = *entry;
101        drop(inv);
102
103        Ok(Value::String(format!(
104            "Recorded {} units for {}. Total count is now {}.",
105            count, sku, total
106        )))
107    }
108}
109
110#[tokio::main]
111async fn main() -> Result<(), anyhow::Error> {
112    // Initialize tracing subscriber
113    tracing_subscriber::fmt()
114        .with_env_filter(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
115        .init();
116
117    // Load environment variables from .env file if present
118    dotenvy::dotenv().ok();
119
120    let mut config = AgentConfig::default();
121
122    if let Ok(harness_path) = std::env::var("ANTIGRAVITY_HARNESS_PATH") {
123        config.binary_path = Some(harness_path);
124    }
125
126    let mut gemini_config = GeminiConfig::default();
127    if let Ok(api_key) = std::env::var("GEMINI_API_KEY") {
128        gemini_config.api_key = Some(api_key);
129    }
130    gemini_config.models.default.name = "gemini-3.5-flash".to_string();
131    config.gemini_config = gemini_config;
132
133    config.system_instructions = Some(antigravity_sdk_rust::types::SystemInstructions::Custom(
134        antigravity_sdk_rust::types::CustomSystemInstructions {
135            text:
136                "You keep track of fruit inventory. To record fruits, you MUST first look up the \
137                   fruit's SKU using lookup_fruit_sku, and then use that SKU with record_fruit."
138                    .to_string(),
139        },
140    ));
141
142    // Initialize shared mutable state for stateful tool
143    let inventory = Arc::new(Mutex::new(HashMap::new()));
144
145    // Register our custom tools
146    config.tools = vec![
147        Arc::new(LookupSkuTool),
148        Arc::new(RecordFruitTool {
149            inventory: inventory.clone(),
150        }),
151    ];
152
153    // Restrict agent to ONLY these tools
154    config.policies = Some(vec![
155        policy::deny_all(),
156        policy::allow("lookup_fruit_sku"),
157        policy::allow("record_fruit"),
158    ]);
159
160    let mut agent = Agent::new(config);
161    println!("Starting agent...");
162    agent.start().await?;
163
164    println!("  === Custom Tools Demo ===");
165
166    // Turn 1: Lookup fruit SKU
167    let prompt1 = "What is the SKU for apples? We need to order more.";
168    println!("\n  User: {}", prompt1);
169    let response1 = agent.chat(prompt1).await?;
170    println!("  Agent: {}", response1.text);
171
172    // Stateful interaction: record fruits across multiple turns
173    println!("\n  === Stateful Tool (Fruit Counter) Demo ===");
174
175    let turns = vec![
176        "I have 5 apples.",
177        "And I just got 3 bananas.",
178        "Oh, and another 2 apples.",
179    ];
180
181    for user_input in turns {
182        println!("\n  User: {}", user_input);
183        let response = agent.chat(user_input).await?;
184        println!("  Agent: {}", response.text);
185    }
186
187    agent.stop().await?;
188    Ok(())
189}