auto_pilot/
action.rs

1use crate::constants::{MAX_TOKENS, OPENAI_ENDPOINT};
2use crate::parsers::{format_vision_prompt, get_last_assistant_message};
3use crate::screen::{add_grid_to_image, capture_screen_with_cursor};
4use crate::types::{
5    ImageMessage, ImageMessageContent, ImageUrl, Message, OpenAIRequest, Role, TextMessage,
6};
7use base64::{engine::general_purpose, Engine as _};
8use reqwest::Client;
9use serde_json::Value;
10use std::{env, fs, path::Path, thread, time::Duration};
11
12pub async fn get_next_action_from_openai(
13    messages: &mut Vec<Message>,
14    objective: &str,
15    grid_interval: i32,
16) -> Result<String, String> {
17    thread::sleep(Duration::from_secs(1));
18
19    let screenshots_dir = "screenshots";
20    if !Path::new(screenshots_dir).exists() {
21        fs::create_dir(screenshots_dir)
22            .map_err(|e| format!("Failed to create directory: {}", e))?;
23    }
24
25    let screenshot_filename = format!("{}/screenshot.png", screenshots_dir);
26    capture_screen_with_cursor(&screenshot_filename)
27        .map_err(|e| format!("Error capturing screen: {}", e))?;
28
29    let new_screenshot_filename = format!("{}/screenshot_with_grid.png", screenshots_dir);
30    add_grid_to_image(
31        &screenshot_filename,
32        &new_screenshot_filename,
33        grid_interval,
34    )
35    .map_err(|e| format!("Error adding grid to image: {}", e))?;
36
37    thread::sleep(Duration::from_secs(1));
38
39    let img_file = fs::read(&new_screenshot_filename)
40        .map_err(|e| format!("Error reading screenshot file: {}", e))?;
41    let img_base64 = general_purpose::STANDARD.encode(&img_file);
42
43    let mut previous_action = get_last_assistant_message(messages);
44    let vision_prompt = format_vision_prompt(objective, &mut previous_action);
45
46    let vision_message = Message::ImageMessage(ImageMessage {
47        role: Role::User,
48        content: vec![
49            ImageMessageContent::Text {
50                text: vision_prompt,
51            },
52            ImageMessageContent::ImageUrl {
53                image_url: ImageUrl {
54                    url: format!("data:image/jpeg;base64,{}", img_base64),
55                },
56            },
57        ],
58    });
59
60    let mut messages_clone = messages.clone();
61    messages_clone.push(vision_message);
62
63    let payload = OpenAIRequest {
64        model: "gpt-4-vision-preview".to_string(),
65        messages: messages_clone,
66        max_tokens: MAX_TOKENS,
67    };
68
69    let content = send_message_to_openai(payload)
70        .await
71        .map_err(|e| format!("Error sending message to OpenAI: {}", e))?;
72
73    messages.push(Message::TextMessage(TextMessage {
74        role: Role::User,
75        content: "screenshot.png".to_string(),
76    }));
77
78    messages.push(Message::TextMessage(TextMessage {
79        role: Role::Assistant,
80        content: content.to_string(),
81    }));
82
83    Ok(content.replace("\\", ""))
84}
85
86pub async fn send_message_to_openai(payload: OpenAIRequest) -> Result<String, String> {
87    let client = Client::new();
88
89    let openai_api_key = env::var("OPENAI_API_KEY")
90        .map_err(|_| "OPENAI_API_KEY not found in environment".to_string())?;
91
92    let seralized_payload = serde_json::to_string(&payload)
93        .map_err(|e| format!("Failed to serialize payload: {}", e))?;
94
95    let response: Value = client
96        .post(OPENAI_ENDPOINT)
97        .header("Content-Type", "application/json")
98        .header("Authorization", format!("Bearer {}", openai_api_key))
99        .body(seralized_payload)
100        .send()
101        .await
102        .map_err(|e| format!("Request failed: {}", e))?
103        .json()
104        .await
105        .map_err(|e| format!("Failed to parse JSON: {}", e))?;
106
107    let content = response["choices"][0]["message"]["content"].to_string();
108
109    Ok(content)
110}