// custom_tools/custom_tools.rs

1use antigravity_sdk_rust::agent::Agent;
2use antigravity_sdk_rust::policy;
3use antigravity_sdk_rust::tools::Tool;
4use serde_json::Value;
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex};
7use tracing_subscriber::EnvFilter;
8
9struct LookupSkuTool;
10
11impl Tool for LookupSkuTool {
12    fn name(&self) -> &'static str {
13        "lookup_fruit_sku"
14    }
15
16    fn description(&self) -> &'static str {
17        "Looks up the SKU for a given fruit."
18    }
19
20    fn parameters_json_schema(&self) -> &'static str {
21        r#"{
22            "type": "object",
23            "properties": {
24                "fruit_name": {
25                    "type": "string"
26                }
27            },
28            "required": ["fruit_name"]
29        }"#
30    }
31
32    async fn call(&self, args: Value) -> Result<Value, anyhow::Error> {
33        let fruit_name = args
34            .get("fruit_name")
35            .and_then(Value::as_str)
36            .ok_or_else(|| anyhow::anyhow!("Missing fruit_name"))?
37            .to_lowercase();
38
39        let sku = match fruit_name.trim_end_matches('s') {
40            "apple" => "SKU-APP-123",
41            "banana" => "SKU-BAN-456",
42            "orange" => "SKU-ORA-789",
43            _ => "SKU-GEN-000",
44        };
45
46        Ok(Value::String(format!(
47            "SKU for {} is {}. Order ID for restocking: ORD-{}-NEW",
48            fruit_name, sku, sku
49        )))
50    }
51}
52
53struct RecordFruitTool {
54    inventory: Arc<Mutex<HashMap<String, u32>>>,
55}
56
57impl Tool for RecordFruitTool {
58    fn name(&self) -> &'static str {
59        "record_fruit"
60    }
61
62    fn description(&self) -> &'static str {
63        "Records the count of fruits by SKU."
64    }
65
66    fn parameters_json_schema(&self) -> &'static str {
67        r#"{
68            "type": "object",
69            "properties": {
70                "sku": {
71                    "type": "string"
72                },
73                "count": {
74                    "type": "integer"
75                }
76            },
77            "required": ["sku", "count"]
78        }"#
79    }
80
81    async fn call(&self, args: Value) -> Result<Value, anyhow::Error> {
82        let sku = args
83            .get("sku")
84            .and_then(Value::as_str)
85            .ok_or_else(|| anyhow::anyhow!("Missing sku"))?
86            .to_string();
87
88        let count = args
89            .get("count")
90            .and_then(Value::as_u64)
91            .ok_or_else(|| anyhow::anyhow!("Missing count"))? as u32;
92
93        let mut inv = self
94            .inventory
95            .lock()
96            .map_err(|e| anyhow::anyhow!("Mutex lock poisoned: {}", e))?;
97        let entry = inv.entry(sku.clone()).or_insert(0);
98        *entry += count;
99        let total = *entry;
100        drop(inv);
101
102        Ok(Value::String(format!(
103            "Recorded {} units for {}. Total count is now {}.",
104            count, sku, total
105        )))
106    }
107}
108
109#[tokio::main]
110async fn main() -> Result<(), anyhow::Error> {
111    // Initialize tracing subscriber
112    tracing_subscriber::fmt()
113        .with_env_filter(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
114        .init();
115
116    // Load environment variables from .env file if present
117    dotenvy::dotenv().ok();
118
119    // Check if the user specified a binary path or check the environment variable
120    let harness_path = std::env::var("ANTIGRAVITY_HARNESS_PATH").ok();
121    let api_key = std::env::var("GEMINI_API_KEY").ok();
122
123    let mut builder = Agent::builder();
124    if let Some(path) = harness_path {
125        builder = builder.binary_path(path);
126    }
127    if let Some(key) = api_key {
128        builder = builder.api_key(key);
129    }
130
131    // Initialize shared mutable state for stateful tool
132    let inventory = Arc::new(Mutex::new(HashMap::new()));
133
134    let agent = builder
135        .default_model("gemini-3.5-flash")
136        .system_instructions(antigravity_sdk_rust::types::SystemInstructions::Custom(
137            antigravity_sdk_rust::types::CustomSystemInstructions {
138                text: "You keep track of fruit inventory. To record fruits, you MUST first look up the \
139                       fruit's SKU using lookup_fruit_sku, and then use that SKU with record_fruit."
140                        .to_string(),
141            },
142        ))
143        .tools(vec![
144            Arc::new(LookupSkuTool),
145            Arc::new(RecordFruitTool {
146                inventory: inventory.clone(),
147            }),
148        ])
149        .policies(vec![
150            policy::deny_all(),
151            policy::allow("lookup_fruit_sku"),
152            policy::allow("record_fruit"),
153        ])
154        .build();
155
156    println!("Starting agent...");
157    let agent = agent.start().await?;
158
159    println!("  === Custom Tools Demo ===");
160
161    // Turn 1: Lookup fruit SKU
162    let prompt1 = "What is the SKU for apples? We need to order more.";
163    println!("\n  User: {}", prompt1);
164    let response1 = agent.chat(prompt1).await?;
165    println!("  Agent: {}", response1.text);
166
167    // Stateful interaction: record fruits across multiple turns
168    println!("\n  === Stateful Tool (Fruit Counter) Demo ===");
169
170    let turns = vec![
171        "I have 5 apples.",
172        "And I just got 3 bananas.",
173        "Oh, and another 2 apples.",
174    ];
175
176    for user_input in turns {
177        println!("\n  User: {}", user_input);
178        let response = agent.chat(user_input).await?;
179        println!("  Agent: {}", response.text);
180    }
181
182    agent.stop().await?;
183    Ok(())
184}