1use crate::constants::{MAX_TOKENS, OPENAI_ENDPOINT};
2use crate::parsers::{format_vision_prompt, get_last_assistant_message};
3use crate::screen::{add_grid_to_image, capture_screen_with_cursor};
4use crate::types::{
5 ImageMessage, ImageMessageContent, ImageUrl, Message, OpenAIRequest, Role, TextMessage,
6};
7use base64::{engine::general_purpose, Engine as _};
8use reqwest::Client;
9use serde_json::Value;
10use std::{env, fs, path::Path, thread, time::Duration};
11
12pub async fn get_next_action_from_openai(
13 messages: &mut Vec<Message>,
14 objective: &str,
15 grid_interval: i32,
16) -> Result<String, String> {
17 thread::sleep(Duration::from_secs(1));
18
19 let screenshots_dir = "screenshots";
20 if !Path::new(screenshots_dir).exists() {
21 fs::create_dir(screenshots_dir)
22 .map_err(|e| format!("Failed to create directory: {}", e))?;
23 }
24
25 let screenshot_filename = format!("{}/screenshot.png", screenshots_dir);
26 capture_screen_with_cursor(&screenshot_filename)
27 .map_err(|e| format!("Error capturing screen: {}", e))?;
28
29 let new_screenshot_filename = format!("{}/screenshot_with_grid.png", screenshots_dir);
30 add_grid_to_image(
31 &screenshot_filename,
32 &new_screenshot_filename,
33 grid_interval,
34 )
35 .map_err(|e| format!("Error adding grid to image: {}", e))?;
36
37 thread::sleep(Duration::from_secs(1));
38
39 let img_file = fs::read(&new_screenshot_filename)
40 .map_err(|e| format!("Error reading screenshot file: {}", e))?;
41 let img_base64 = general_purpose::STANDARD.encode(&img_file);
42
43 let mut previous_action = get_last_assistant_message(messages);
44 let vision_prompt = format_vision_prompt(objective, &mut previous_action);
45
46 let vision_message = Message::ImageMessage(ImageMessage {
47 role: Role::User,
48 content: vec![
49 ImageMessageContent::Text {
50 text: vision_prompt,
51 },
52 ImageMessageContent::ImageUrl {
53 image_url: ImageUrl {
54 url: format!("data:image/jpeg;base64,{}", img_base64),
55 },
56 },
57 ],
58 });
59
60 let mut messages_clone = messages.clone();
61 messages_clone.push(vision_message);
62
63 let payload = OpenAIRequest {
64 model: "gpt-4-vision-preview".to_string(),
65 messages: messages_clone,
66 max_tokens: MAX_TOKENS,
67 };
68
69 let content = send_message_to_openai(payload)
70 .await
71 .map_err(|e| format!("Error sending message to OpenAI: {}", e))?;
72
73 messages.push(Message::TextMessage(TextMessage {
74 role: Role::User,
75 content: "screenshot.png".to_string(),
76 }));
77
78 messages.push(Message::TextMessage(TextMessage {
79 role: Role::Assistant,
80 content: content.to_string(),
81 }));
82
83 Ok(content.replace("\\", ""))
84}
85
86pub async fn send_message_to_openai(payload: OpenAIRequest) -> Result<String, String> {
87 let client = Client::new();
88
89 let openai_api_key = env::var("OPENAI_API_KEY")
90 .map_err(|_| "OPENAI_API_KEY not found in environment".to_string())?;
91
92 let seralized_payload = serde_json::to_string(&payload)
93 .map_err(|e| format!("Failed to serialize payload: {}", e))?;
94
95 let response: Value = client
96 .post(OPENAI_ENDPOINT)
97 .header("Content-Type", "application/json")
98 .header("Authorization", format!("Bearer {}", openai_api_key))
99 .body(seralized_payload)
100 .send()
101 .await
102 .map_err(|e| format!("Request failed: {}", e))?
103 .json()
104 .await
105 .map_err(|e| format!("Failed to parse JSON: {}", e))?;
106
107 let content = response["choices"][0]["message"]["content"].to_string();
108
109 Ok(content)
110}