// genauai_kernel/lib.rs

1use rusqlite::{Connection, Error, params};
2use serde::{Deserialize, Serialize};
3use dirs::data_dir;
4use std::error::Error as StdError;
5use std::path::PathBuf;
6use dotenv::dotenv;
7use async_openai::{
8    types::CreateCompletionRequestArgs,
9    Client,
10};
11
12#[derive(Debug, Serialize, Deserialize)]
13pub struct Plan {
14    pub steps: Vec<String>,
15}
16
17#[derive(Debug, Serialize, Deserialize)]
18pub struct Message {
19    pub id: i32,
20    pub workspace_id: i32,
21    pub sender: String,
22    pub text: String,
23    pub created_at: String,
24}
25
/// The single workspace that `get_messages`, `get_plan` and `reset_database`
/// are scoped to; multiple workspaces are not supported by those helpers yet.
const WORKSPACE_ID: i32 = 1;
27
28pub fn get_database_path() -> Result<PathBuf, Box<dyn StdError>> {
29    let mut path = data_dir().ok_or("Could not find data directory.")?;
30    path.push("genau");
31    path.push("genau.db");
32    Ok(path)
33}
34
35pub fn get_db() -> Result<Connection, Error> {
36    let path = get_database_path().unwrap();
37    let conn = Connection::open(path)?;
38    create_tables(&conn)?;
39    Ok(conn)
40}
41
42pub fn save_message(conn: &Connection, message: &Message) -> Result<(), Error> {
43    conn.execute(
44        "INSERT INTO messages (workspace_id, sender, text, created_at) VALUES (?1, ?2, ?3, datetime('now'))",
45        params![message.workspace_id, message.sender, message.text],
46    )?;
47
48    Ok(())
49}
50
51pub fn reset_database(conn: &Connection) -> Result<(), Error> {
52    conn.execute(
53        "DELETE FROM messages WHERE workspace_id = ?1",
54        [WORKSPACE_ID],
55    )?;
56    conn.execute("DELETE FROM plans WHERE workspace_id = ?1", [WORKSPACE_ID])?;
57    Ok(())
58}
59
60pub fn save_plan(conn: &Connection, workspace_id: i32, plan: &Plan) -> Result<(), Error> {
61    let steps = serde_json::to_string(&plan.steps).unwrap();
62    conn.execute(
63        "INSERT OR REPLACE INTO plans (workspace_id, steps) VALUES (?1, ?2)",
64        params![workspace_id, steps],
65    )?;
66
67    Ok(())
68}
69
70pub fn get_messages(conn: &Connection) -> Result<Vec<Message>, Error> {
71    let mut stmt =
72        conn.prepare("SELECT * FROM messages WHERE workspace_id = ?1 ORDER BY id DESC")?;
73    let rows = stmt.query_map([WORKSPACE_ID], |row| {
74        Ok(Message {
75            id: row.get(0)?,
76            workspace_id: row.get(1)?,
77            sender: row.get(2)?,
78            text: row.get(3)?,
79            created_at: row.get(4)?,
80        })
81    })?;
82
83    rows.collect()
84}
85
86pub fn create_tables(conn: &Connection) -> Result<(), Error> {
87    conn.execute(
88        "CREATE TABLE IF NOT EXISTS messages (
89            id              INTEGER PRIMARY KEY,
90            workspace_id    INTEGER NOT NULL,
91            sender          TEXT NOT NULL,
92            text            TEXT NOT NULL,
93            created_at      TEXT NOT NULL
94        )",
95        [],
96    )?;
97
98    conn.execute(
99        "CREATE TABLE IF NOT EXISTS plans (
100            workspace_id    INTEGER PRIMARY KEY,
101            steps           TEXT NOT NULL
102        )",
103        [],
104    )?;
105
106    Ok(())
107}
108
109pub fn get_plan(conn: &Connection) -> Result<Plan, Error> {
110    let mut stmt = conn.prepare("SELECT steps FROM plans WHERE workspace_id = ?1")?;
111    let mut rows = stmt.query_map([WORKSPACE_ID], |row| {
112        let steps: String = row.get(0)?;
113        let steps = serde_json::from_str(&steps).unwrap();
114        Ok(Plan { steps })
115    })?;
116
117    if let Some(plan) = rows.next().transpose()? {
118        Ok(plan)
119    } else {
120        let plan = Plan { steps: vec![] };
121        save_plan(conn, WORKSPACE_ID, &plan)?;
122        Ok(plan)
123    }
124}
125
126pub async fn do_stuff() {
127    println!("Hello, world!");
128    dotenv().ok();
129
130    let client = Client::new();
131
132    let request = CreateCompletionRequestArgs::default()
133        .model("text-davinci-003")
134        .prompt("Tell me the recipe of alfredo pasta")
135        .max_tokens(40_u16)
136        .build()
137        .unwrap();
138
139    // Call API
140    let response = client
141        .completions() // Get the API "group" (completions, images, etc.) from the client
142        .create(request) // Make the API call in that "group"
143        .await
144        .unwrap();
145
146    dbg!(response);
147}