aiscript_vm/ai/prompt.rs

use openai_api_rs::v1::common::GPT3_5_TURBO;
use tokio::runtime::Handle;

/// Options for a single prompt call; `None` fields fall back to defaults.
pub struct PromptConfig {
    pub input: String,
    pub model: Option<String>,
    pub max_tokens: Option<i64>,
    pub temperature: Option<f64>,
    pub system_prompt: Option<String>,
}

impl Default for PromptConfig {
    fn default() -> Self {
        Self {
            input: String::new(),
            model: Some(GPT3_5_TURBO.to_string()),
            max_tokens: Default::default(),
            temperature: Default::default(),
            system_prompt: Default::default(),
        }
    }
}

#[cfg(feature = "ai_test")]
async fn _prompt_with_config(config: PromptConfig) -> String {
    format!("AI: {}", config.input)
}

#[cfg(not(feature = "ai_test"))]
async fn _prompt_with_config(mut config: PromptConfig) -> String {
    use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};

    let mut client = super::openai_client();

    // Create the system message if a system prompt was provided
    let mut messages = Vec::new();
    if let Some(system_prompt) = config.system_prompt.take() {
        messages.push(chat_completion::ChatCompletionMessage {
            role: chat_completion::MessageRole::system,
            content: chat_completion::Content::Text(system_prompt),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        });
    }

    // Add the user message
    messages.push(chat_completion::ChatCompletionMessage {
        role: chat_completion::MessageRole::user,
        content: chat_completion::Content::Text(config.input),
        name: None,
        tool_calls: None,
        tool_call_id: None,
    });

    // Build the request, falling back to GPT-3.5 Turbo when no model is set
    let mut req = ChatCompletionRequest::new(
        config
            .model
            .take()
            .unwrap_or_else(|| GPT3_5_TURBO.to_string()),
        messages,
    );

    if let Some(max_tokens) = config.max_tokens {
        req.max_tokens = Some(max_tokens);
    }

    if let Some(temperature) = config.temperature {
        req.temperature = Some(temperature);
    }

    // Panics on transport or API errors; acceptable for a first pass, but a
    // production path should propagate the error instead.
    let result = client.chat_completion(req).await.unwrap();
    result.choices[0]
        .message
        .content
        .clone()
        .unwrap_or_default()
}

pub fn prompt_with_config(config: PromptConfig) -> String {
    match Handle::try_current() {
        // Already inside a Tokio runtime. Handle::block_on panics when
        // called on a runtime worker thread, so mark this thread as
        // blocking via block_in_place first (this requires the
        // multi-thread runtime flavor).
        Ok(handle) => {
            tokio::task::block_in_place(|| handle.block_on(_prompt_with_config(config)))
        }
        // No runtime available: create a temporary one for this call.
        Err(_) => {
            let runtime = tokio::runtime::Runtime::new().unwrap();
            runtime.block_on(_prompt_with_config(config))
        }
    }
}
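
For reference, here is a minimal sketch of driving this module from synchronous code. The import path assumes these items are reachable at `aiscript_vm::ai` (the listing's location suggests `ai::prompt`, so adjust for however the crate actually re-exports them), and it assumes `super::openai_client()` picks up credentials such as `OPENAI_API_KEY` from the environment:

use aiscript_vm::ai::{prompt_with_config, PromptConfig};

fn main() {
    // Only `input` has to be set; the rest keep their `Default` values
    // (GPT-3.5 Turbo, no max_tokens or temperature overrides).
    let config = PromptConfig {
        input: "Explain ownership in one sentence.".to_string(),
        system_prompt: Some("You are a terse Rust tutor.".to_string()),
        ..Default::default()
    };

    // No ambient Tokio runtime here, so prompt_with_config creates a
    // temporary one and blocks until the completion returns.
    let answer = prompt_with_config(config);
    println!("{answer}");
}

Because `prompt_with_config` blocks, callers already inside a Tokio runtime go through the `block_in_place` branch, which requires the multi-thread runtime flavor.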
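
The `ai_test` cfg gates above suggest a feature flag for deterministic, offline tests. Below is a sketch of a test module that leans on the stub; the feature name comes from the attributes above, but the test itself is illustrative:

// Run with: cargo test --features ai_test
#[cfg(all(test, feature = "ai_test"))]
mod tests {
    use super::*;

    #[test]
    fn stub_echoes_input() {
        let config = PromptConfig {
            input: "ping".to_string(),
            ..Default::default()
        };
        // With ai_test enabled, _prompt_with_config is the offline stub,
        // so no network access or API key is needed.
        assert_eq!(prompt_with_config(config), "AI: ping");
    }
}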