1use rusqlite::{Connection, Error, params};
2use serde::{Deserialize, Serialize};
3use dirs::data_dir;
4use std::error::Error as StdError;
5use std::path::PathBuf;
6use dotenv::dotenv;
7use async_openai::{
8 types::CreateCompletionRequestArgs,
9 Client,
10};
11
12#[derive(Debug, Serialize, Deserialize)]
13pub struct Plan {
14 pub steps: Vec<String>,
15}
16
17#[derive(Debug, Serialize, Deserialize)]
18pub struct Message {
19 pub id: i32,
20 pub workspace_id: i32,
21 pub sender: String,
22 pub text: String,
23 pub created_at: String,
24}
25
/// Workspace queried by `get_messages` and `get_plan` and cleared by `reset_database`.
const WORKSPACE_ID: i32 = 1;
27
28pub fn get_database_path() -> Result<PathBuf, Box<dyn StdError>> {
29 let mut path = data_dir().ok_or("Could not find data directory.")?;
30 path.push("genau");
31 path.push("genau.db");
32 Ok(path)
33}
34
35pub fn get_db() -> Result<Connection, Error> {
36 let path = get_database_path().unwrap();
37 let conn = Connection::open(path)?;
38 create_tables(&conn)?;
39 Ok(conn)
40}
41
42pub fn save_message(conn: &Connection, message: &Message) -> Result<(), Error> {
43 conn.execute(
44 "INSERT INTO messages (workspace_id, sender, text, created_at) VALUES (?1, ?2, ?3, datetime('now'))",
45 params![message.workspace_id, message.sender, message.text],
46 )?;
47
48 Ok(())
49}
50
51pub fn reset_database(conn: &Connection) -> Result<(), Error> {
52 conn.execute(
53 "DELETE FROM messages WHERE workspace_id = ?1",
54 [WORKSPACE_ID],
55 )?;
56 conn.execute("DELETE FROM plans WHERE workspace_id = ?1", [WORKSPACE_ID])?;
57 Ok(())
58}
59
60pub fn save_plan(conn: &Connection, workspace_id: i32, plan: &Plan) -> Result<(), Error> {
61 let steps = serde_json::to_string(&plan.steps).unwrap();
62 conn.execute(
63 "INSERT OR REPLACE INTO plans (workspace_id, steps) VALUES (?1, ?2)",
64 params![workspace_id, steps],
65 )?;
66
67 Ok(())
68}
69
70pub fn get_messages(conn: &Connection) -> Result<Vec<Message>, Error> {
71 let mut stmt =
72 conn.prepare("SELECT * FROM messages WHERE workspace_id = ?1 ORDER BY id DESC")?;
73 let rows = stmt.query_map([WORKSPACE_ID], |row| {
74 Ok(Message {
75 id: row.get(0)?,
76 workspace_id: row.get(1)?,
77 sender: row.get(2)?,
78 text: row.get(3)?,
79 created_at: row.get(4)?,
80 })
81 })?;
82
83 rows.collect()
84}
85
86pub fn create_tables(conn: &Connection) -> Result<(), Error> {
87 conn.execute(
88 "CREATE TABLE IF NOT EXISTS messages (
89 id INTEGER PRIMARY KEY,
90 workspace_id INTEGER NOT NULL,
91 sender TEXT NOT NULL,
92 text TEXT NOT NULL,
93 created_at TEXT NOT NULL
94 )",
95 [],
96 )?;
97
98 conn.execute(
99 "CREATE TABLE IF NOT EXISTS plans (
100 workspace_id INTEGER PRIMARY KEY,
101 steps TEXT NOT NULL
102 )",
103 [],
104 )?;
105
106 Ok(())
107}
108
109pub fn get_plan(conn: &Connection) -> Result<Plan, Error> {
110 let mut stmt = conn.prepare("SELECT steps FROM plans WHERE workspace_id = ?1")?;
111 let mut rows = stmt.query_map([WORKSPACE_ID], |row| {
112 let steps: String = row.get(0)?;
113 let steps = serde_json::from_str(&steps).unwrap();
114 Ok(Plan { steps })
115 })?;
116
117 if let Some(plan) = rows.next().transpose()? {
118 Ok(plan)
119 } else {
120 let plan = Plan { steps: vec![] };
121 save_plan(conn, WORKSPACE_ID, &plan)?;
122 Ok(plan)
123 }
124}
125
126pub async fn do_stuff() {
127 println!("Hello, world!");
128 dotenv().ok();
129
130 let client = Client::new();
131
132 let request = CreateCompletionRequestArgs::default()
133 .model("text-davinci-003")
134 .prompt("Tell me the recipe of alfredo pasta")
135 .max_tokens(40_u16)
136 .build()
137 .unwrap();
138
139 let response = client
141 .completions() .create(request) .await
144 .unwrap();
145
146 dbg!(response);
147}