1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use rusqlite::{Connection, Error, params};
use serde::{Deserialize, Serialize};
use dirs::data_dir;
use std::error::Error as StdError;
use std::path::PathBuf;
use dotenv::dotenv;
use async_openai::{
    types::CreateCompletionRequestArgs,
    Client,
};

/// An ordered list of plan steps for a workspace.
///
/// `save_plan` serializes `steps` to a JSON string and stores it in the
/// `steps` column of the `plans` table; `get_plan` parses it back.
#[derive(Debug, Serialize, Deserialize)]
pub struct Plan {
    /// The individual steps, in order.
    pub steps: Vec<String>,
}

/// One row of the `messages` table.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Value of the `id` primary-key column. Not written by `save_message`,
    /// which leaves the column for SQLite to assign.
    pub id: i32,
    /// Workspace the message belongs to.
    pub workspace_id: i32,
    /// Who sent the message.
    pub sender: String,
    /// Message body.
    pub text: String,
    /// Value of the `created_at` column. Not written by `save_message`,
    /// which stores `datetime('now')` instead.
    pub created_at: String,
}

/// Workspace id used by `reset_database`, `get_messages` and `get_plan`,
/// none of which take a workspace parameter.
const WORKSPACE_ID: i32 = 1;

/// Returns the location of the SQLite database file:
/// `<data dir>/genau/genau.db`, where `<data dir>` is whatever
/// `dirs::data_dir()` reports.
///
/// This only builds the path; it does not touch the filesystem.
///
/// # Errors
///
/// Fails with the message `"Could not find data directory."` when
/// `dirs::data_dir()` returns `None`.
pub fn get_database_path() -> Result<PathBuf, Box<dyn StdError>> {
    match data_dir() {
        Some(base) => Ok(base.join("genau").join("genau.db")),
        None => Err("Could not find data directory.".into()),
    }
}

pub fn get_db() -> Result<Connection, Error> {
    let path = get_database_path().unwrap();
    let conn = Connection::open(path)?;
    create_tables(&conn)?;
    Ok(conn)
}

pub fn save_message(conn: &Connection, message: &Message) -> Result<(), Error> {
    conn.execute(
        "INSERT INTO messages (workspace_id, sender, text, created_at) VALUES (?1, ?2, ?3, datetime('now'))",
        params![message.workspace_id, message.sender, message.text],
    )?;

    Ok(())
}

/// Deletes every message and the stored plan belonging to `WORKSPACE_ID`.
///
/// The two deletes run one after the other, `messages` first; if the first
/// one fails the second is not attempted.
///
/// # Errors
///
/// Returns the error from the first `Connection::execute` call that fails.
pub fn reset_database(conn: &Connection) -> Result<(), Error> {
    let statements = [
        "DELETE FROM messages WHERE workspace_id = ?1",
        "DELETE FROM plans WHERE workspace_id = ?1",
    ];

    for sql in statements {
        conn.execute(sql, [WORKSPACE_ID])?;
    }
    Ok(())
}

pub fn save_plan(conn: &Connection, workspace_id: i32, plan: &Plan) -> Result<(), Error> {
    let steps = serde_json::to_string(&plan.steps).unwrap();
    conn.execute(
        "INSERT OR REPLACE INTO plans (workspace_id, steps) VALUES (?1, ?2)",
        params![workspace_id, steps],
    )?;

    Ok(())
}

/// Loads all messages of `WORKSPACE_ID`, ordered by `id` descending.
///
/// # Errors
///
/// Returns the error from preparing or running the query, or from reading
/// any column of any row.
pub fn get_messages(conn: &Connection) -> Result<Vec<Message>, Error> {
    // Columns are listed explicitly so the positional `row.get(n)` calls
    // below do not depend on the physical column order of the table.
    let mut stmt = conn.prepare(
        "SELECT id, workspace_id, sender, text, created_at \
         FROM messages WHERE workspace_id = ?1 ORDER BY id DESC",
    )?;

    let rows = stmt.query_map([WORKSPACE_ID], |row| {
        Ok(Message {
            id: row.get(0)?,
            workspace_id: row.get(1)?,
            sender: row.get(2)?,
            text: row.get(3)?,
            created_at: row.get(4)?,
        })
    })?;

    // Collecting into `Result<Vec<_>, _>` stops at the first failing row.
    rows.collect()
}

/// Creates the `messages` and `plans` tables if they do not exist yet.
///
/// Both statements use `CREATE TABLE IF NOT EXISTS`, so calling this on a
/// database that already has the tables leaves them as they are.
///
/// # Errors
///
/// Returns the error from the first `Connection::execute` call that fails.
pub fn create_tables(conn: &Connection) -> Result<(), Error> {
    let messages_ddl = "CREATE TABLE IF NOT EXISTS messages (
            id              INTEGER PRIMARY KEY,
            workspace_id    INTEGER NOT NULL,
            sender          TEXT NOT NULL,
            text            TEXT NOT NULL,
            created_at      TEXT NOT NULL
        )";
    let plans_ddl = "CREATE TABLE IF NOT EXISTS plans (
            workspace_id    INTEGER PRIMARY KEY,
            steps           TEXT NOT NULL
        )";

    for ddl in [messages_ddl, plans_ddl] {
        conn.execute(ddl, [])?;
    }
    Ok(())
}

pub fn get_plan(conn: &Connection) -> Result<Plan, Error> {
    let mut stmt = conn.prepare("SELECT steps FROM plans WHERE workspace_id = ?1")?;
    let mut rows = stmt.query_map([WORKSPACE_ID], |row| {
        let steps: String = row.get(0)?;
        let steps = serde_json::from_str(&steps).unwrap();
        Ok(Plan { steps })
    })?;

    if let Some(plan) = rows.next().transpose()? {
        Ok(plan)
    } else {
        let plan = Plan { steps: vec![] };
        save_plan(conn, WORKSPACE_ID, &plan)?;
        Ok(plan)
    }
}

/// Sends one fixed completion request to the OpenAI API and prints the
/// response with `dbg!`.
///
/// Variables from a `.env` file are loaded first, if one is present; a
/// missing or unreadable `.env` is ignored.
///
/// A failure to build the request or to complete the API call is reported
/// on stderr and the function returns; it no longer panics in those cases.
pub async fn do_stuff() {
    println!("Hello, world!");
    // Must run before the client is created so values from `.env` are
    // already in the environment.
    dotenv().ok();

    let client = Client::new();

    let request = match CreateCompletionRequestArgs::default()
        .model("text-davinci-003")
        .prompt("Tell me the recipe of alfredo pasta")
        .max_tokens(40_u16)
        .build()
    {
        Ok(request) => request,
        Err(err) => {
            eprintln!("failed to build completion request: {}", err);
            return;
        }
    };

    // `completions()` selects the API group; `create` performs the call.
    match client.completions().create(request).await {
        Ok(response) => {
            dbg!(response);
        }
        Err(err) => eprintln!("completion request failed: {}", err),
    }
}