1use anyhow::{anyhow, Result};
10use serde::{Deserialize, Serialize};
11use std::env;
12use std::time::Duration;
13use tracing::{debug, info, warn};
14
15#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct ComputerUseFunctionCall {
20 pub name: String,
22 #[serde(default)]
24 pub args: serde_json::Value,
25 pub id: Option<String>,
27}
28
29#[derive(Debug, Clone)]
31pub struct ComputerUseResponse {
32 pub completed: bool,
34 pub function_call: Option<ComputerUseFunctionCall>,
36 pub text: Option<String>,
38 pub safety_decision: Option<String>,
40}
41
42#[derive(Debug, Serialize, Clone)]
44pub struct ComputerUsePreviousAction {
45 pub name: String,
47 pub response: ComputerUseActionResponse,
49 pub screenshot: String,
51 #[serde(skip_serializing_if = "Option::is_none")]
53 pub url: Option<String>,
54}
55
56#[derive(Debug, Serialize, Clone)]
58pub struct ComputerUseActionResponse {
59 pub success: bool,
61 #[serde(skip_serializing_if = "Option::is_none")]
63 pub error: Option<String>,
64}
65
66#[derive(Debug, Clone, Serialize)]
68pub struct ComputerUseStep {
69 pub step: u32,
71 pub action: String,
73 pub args: serde_json::Value,
75 pub success: bool,
77 pub error: Option<String>,
79 pub text: Option<String>,
81}
82
83#[derive(Debug, Clone, Serialize)]
85pub struct ComputerUseResult {
86 pub status: String,
88 pub goal: String,
90 pub steps_executed: u32,
92 pub final_action: String,
94 pub final_text: Option<String>,
96 pub steps: Vec<ComputerUseStep>,
98 pub pending_confirmation: Option<serde_json::Value>,
100 pub execution_id: Option<String>,
102}
103
104pub type ProgressCallback = Box<dyn Fn(&ComputerUseStep) + Send + Sync>;
106
107#[derive(Debug, Deserialize)]
111struct ComputerUseBackendResponse {
112 completed: bool,
113 #[serde(default)]
114 function_call: Option<ComputerUseFunctionCall>,
115 text: Option<String>,
116 safety_decision: Option<String>,
117 #[allow(dead_code)]
118 duration_ms: Option<u64>,
119 #[allow(dead_code)]
120 model_used: Option<String>,
121 error: Option<String>,
122}
123
124pub async fn call_computer_use_backend(
141 base64_image: &str,
142 goal: &str,
143 previous_actions: Option<&[ComputerUsePreviousAction]>,
144) -> Result<ComputerUseResponse> {
145 let backend_url = env::var("GEMINI_COMPUTER_USE_BACKEND_URL")
146 .unwrap_or_else(|_| "https://app.mediar.ai/api/vision/computer-use".to_string());
147
148 info!(
149 "[computer_use] Calling backend at {} (goal: {})",
150 backend_url,
151 &goal[..goal.len().min(50)]
152 );
153
154 let client = reqwest::Client::builder()
155 .timeout(Duration::from_secs(300))
156 .build()?;
157
158 let payload = serde_json::json!({
159 "image": base64_image,
160 "goal": goal,
161 "previous_actions": previous_actions.unwrap_or(&[])
162 });
163
164 let resp = client
165 .post(&backend_url)
166 .header("Content-Type", "application/json")
167 .json(&payload)
168 .send()
169 .await?;
170
171 let status = resp.status();
172 if !status.is_success() {
173 let text = resp.text().await.unwrap_or_default();
174 warn!("[computer_use] Backend error: {} - {}", status, text);
175 return Err(anyhow!("Computer Use backend error ({}): {}", status, text));
176 }
177
178 let response_text = resp.text().await?;
179 debug!(
180 "[computer_use] Backend response: {}",
181 &response_text[..response_text.len().min(500)]
182 );
183
184 let backend_response: ComputerUseBackendResponse = serde_json::from_str(&response_text)
185 .map_err(|e| anyhow!("Failed to parse backend response: {}", e))?;
186
187 if let Some(error) = backend_response.error {
188 return Err(anyhow!("Computer Use error: {}", error));
189 }
190
191 Ok(ComputerUseResponse {
192 completed: backend_response.completed,
193 function_call: backend_response.function_call,
194 text: backend_response.text,
195 safety_decision: backend_response.safety_decision,
196 })
197}
198
/// Translates a `+`-separated key combination such as `"control+a"` into a
/// string of brace-delimited key tokens such as `"{Ctrl}a"`.
///
/// Each segment is trimmed and lower-cased, then resolved in this order:
/// 1. a named modifier, editing or navigation key (e.g. `ctrl`, `enter`,
///    `pgdn`, `arrowup`) becomes its `{Token}`;
/// 2. a segment of two or more bytes starting with `f` is treated as a
///    function key and must parse to a number in `1..=24`;
/// 3. a one-byte segment in the last position is emitted as-is (lower-cased).
///
/// # Errors
///
/// Returns a descriptive message when a segment looks like a function key
/// but is not `f1`-`f24`, or when a segment matches none of the rules above
/// (including an empty segment or a single character that is not last).
pub fn translate_gemini_keys(gemini_keys: &str) -> Result<String, String> {
    /// Looks up the token for a named (non-function, non-literal) key.
    fn named_token(key: &str) -> Option<&'static str> {
        let token = match key {
            // Modifiers.
            "control" | "ctrl" => "{Ctrl}",
            "alt" => "{Alt}",
            "shift" => "{Shift}",
            "meta" | "cmd" | "command" | "win" | "windows" | "super" => "{Win}",
            // Editing and miscellaneous keys.
            "enter" | "return" => "{Enter}",
            "tab" => "{Tab}",
            "escape" | "esc" => "{Escape}",
            "backspace" | "back" => "{Backspace}",
            "delete" | "del" => "{Delete}",
            "space" => "{Space}",
            "insert" | "ins" => "{Insert}",
            "home" => "{Home}",
            "end" => "{End}",
            "pageup" | "pgup" => "{PageUp}",
            "pagedown" | "pgdown" | "pgdn" => "{PageDown}",
            "printscreen" | "prtsc" => "{PrintScreen}",
            // Arrow keys.
            "up" | "arrowup" => "{Up}",
            "down" | "arrowdown" => "{Down}",
            "left" | "arrowleft" => "{Left}",
            "right" | "arrowright" => "{Right}",
            _ => return None,
        };
        Some(token)
    }

    let segments: Vec<&str> = gemini_keys.split('+').collect();
    // `split` always yields at least one segment, so this cannot underflow.
    let last_index = segments.len() - 1;
    let mut output = String::new();

    for (index, segment) in segments.iter().enumerate() {
        let key = segment.trim().to_lowercase();

        if let Some(token) = named_token(&key) {
            output.push_str(token);
        } else if key.starts_with('f') && key.len() >= 2 {
            // Anything of the form `f<something>` must be a valid F-key.
            let number = key[1..]
                .parse::<u8>()
                .ok()
                .filter(|n| (1..=24).contains(n));
            match number {
                Some(n) => output.push_str(&format!("{{F{}}}", n)),
                None => {
                    return Err(format!(
                        "Invalid function key '{}' in '{}'. Use f1-f24.",
                        segment, gemini_keys
                    ));
                }
            }
        } else if key.len() == 1 && index == last_index {
            // A lone character is only accepted as the final key of a combo.
            output.push_str(&key);
        } else {
            return Err(format!(
                "Unknown key '{}' in combination '{}'. Valid: enter, tab, escape, \
                 backspace, delete, space, up/down/left/right, home, end, pageup, \
                 pagedown, f1-f24, or modifiers (ctrl, alt, shift, meta) with letters.",
                key, gemini_keys
            ));
        }
    }

    Ok(output)
}
314
/// Maps a point given on a 0-1000 normalized scale to screen coordinates.
///
/// Per axis the computation is:
/// `window_origin + (norm / 1000 * screenshot_extent) / resize_scale / dpi_scale`.
/// That is, the normalized value is scaled to screenshot pixels, the resize
/// factor and then the DPI factor are divided out, and the window origin is
/// added.
///
/// Returns `(screen_x, screen_y)`. No validation is performed: a zero
/// `resize_scale` or `dpi_scale` produces non-finite results.
// The eight scalars are kept as separate parameters so existing callers keep
// working; hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub fn convert_normalized_to_screen(
    norm_x: f64,
    norm_y: f64,
    window_x: f64,
    window_y: f64,
    screenshot_w: f64,
    screenshot_h: f64,
    dpi_scale: f64,
    resize_scale: f64,
) -> (f64, f64) {
    // Both axes go through the same pipeline; only the inputs differ.
    let to_screen = |norm: f64, extent: f64, origin: f64| -> f64 {
        let screenshot_px = (norm / 1000.0) * extent;
        let original_px = screenshot_px / resize_scale;
        let logical = original_px / dpi_scale;
        origin + logical
    };

    (
        to_screen(norm_x, screenshot_w, window_x),
        to_screen(norm_y, screenshot_h, window_y),
    )
}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every `(input, expected)` pair translates successfully.
    fn assert_translations(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(translate_gemini_keys(input).unwrap(), expected);
        }
    }

    /// Asserts that `actual` is within 0.001 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.001);
    }

    #[test]
    fn test_translate_gemini_keys_simple() {
        assert_translations(&[
            ("enter", "{Enter}"),
            ("tab", "{Tab}"),
            ("escape", "{Escape}"),
        ]);
    }

    #[test]
    fn test_translate_gemini_keys_modifiers() {
        assert_translations(&[
            ("control+a", "{Ctrl}a"),
            ("ctrl+c", "{Ctrl}c"),
            ("Meta+Shift+T", "{Win}{Shift}t"),
        ]);
    }

    #[test]
    fn test_translate_gemini_keys_function() {
        assert_translations(&[
            ("f1", "{F1}"),
            ("f12", "{F12}"),
            ("alt+f4", "{Alt}{F4}"),
        ]);
    }

    #[test]
    fn test_convert_normalized_coords() {
        // Window at the origin, no scaling: the midpoint maps to the pixel midpoint.
        let (x, y) =
            convert_normalized_to_screen(500.0, 500.0, 0.0, 0.0, 1000.0, 1000.0, 1.0, 1.0);
        assert_close(x, 500.0);
        assert_close(y, 500.0);

        // A window offset is added on top of the in-window position.
        let (x, y) =
            convert_normalized_to_screen(500.0, 500.0, 100.0, 200.0, 1000.0, 1000.0, 1.0, 1.0);
        assert_close(x, 600.0);
        assert_close(y, 700.0);
    }
}