mostro_client/
db.rs

1use crate::util::get_mcli_path;
2use anyhow::Result;
3use mostro_core::prelude::*;
4use nip06::FromMnemonic;
5use nostr_sdk::prelude::*;
6use sqlx::pool::Pool;
7use sqlx::Sqlite;
8use sqlx::SqlitePool;
9use std::fs::File;
10use std::path::Path;
11
12pub async fn connect() -> Result<Pool<Sqlite>> {
13    let mcli_dir = get_mcli_path();
14    let mcli_db_path = format!("{}/mcli.db", mcli_dir);
15    let db_url = format!("sqlite://{}", mcli_db_path);
16    let pool: Pool<Sqlite>;
17    if !Path::exists(Path::new(&mcli_db_path)) {
18        if let Err(res) = File::create(&mcli_db_path) {
19            println!("Error in creating db file: {}", res);
20            return Err(res.into());
21        }
22        pool = SqlitePool::connect(&db_url).await?;
23        println!("Creating database file with orders table...");
24        sqlx::query(
25            r#"
26          CREATE TABLE IF NOT EXISTS orders (
27              id TEXT PRIMARY KEY,
28              kind TEXT NOT NULL,
29              status TEXT NOT NULL,
30              amount INTEGER NOT NULL,
31              min_amount INTEGER,
32              max_amount INTEGER,
33              fiat_code TEXT NOT NULL,
34              fiat_amount INTEGER NOT NULL,
35              payment_method TEXT NOT NULL,
36              premium INTEGER NOT NULL,
37              trade_keys TEXT,
38              counterparty_pubkey TEXT,
39              is_mine BOOLEAN,
40              buyer_invoice TEXT,
41              request_id INTEGER,
42              created_at INTEGER,
43              expires_at INTEGER
44          );
45          CREATE TABLE IF NOT EXISTS users (
46              i0_pubkey char(64) PRIMARY KEY,
47              mnemonic TEXT,
48              last_trade_index INTEGER,
49              created_at INTEGER
50          );
51          "#,
52        )
53        .execute(&pool)
54        .await?;
55
56        let mnemonic = match Mnemonic::generate(12) {
57            Ok(m) => m.to_string(),
58            Err(e) => {
59                println!("Error generating mnemonic: {}", e);
60                return Err(e.into());
61            }
62        };
63        let user = User::new(mnemonic, &pool).await?;
64        println!("User created with pubkey: {}", user.i0_pubkey);
65    } else {
66        pool = SqlitePool::connect(&db_url).await?;
67
68        // Migration: Drop buyer_token and seller_token columns if they exist
69        migrate_remove_token_columns(&pool).await?;
70    }
71
72    Ok(pool)
73}
74
75async fn migrate_remove_token_columns(pool: &SqlitePool) -> Result<()> {
76    println!("Checking for legacy token columns...");
77
78    // Check if buyer_token column exists
79    let buyer_token_exists = sqlx::query_scalar::<_, i64>(
80        "SELECT COUNT(*) FROM pragma_table_info('orders') WHERE name = 'buyer_token'",
81    )
82    .fetch_one(pool)
83    .await?;
84
85    // Check if seller_token column exists
86    let seller_token_exists = sqlx::query_scalar::<_, i64>(
87        "SELECT COUNT(*) FROM pragma_table_info('orders') WHERE name = 'seller_token'",
88    )
89    .fetch_one(pool)
90    .await?;
91
92    // Drop buyer_token column if it exists
93    if buyer_token_exists > 0 {
94        println!("Removing legacy buyer_token column...");
95        match sqlx::query("ALTER TABLE orders DROP COLUMN buyer_token")
96            .execute(pool)
97            .await
98        {
99            Ok(_) => println!("Successfully removed buyer_token column"),
100            Err(e) => {
101                println!("Warning: Could not remove buyer_token column: {}", e);
102                // Continue execution - this is not critical
103            }
104        }
105    }
106
107    // Drop seller_token column if it exists
108    if seller_token_exists > 0 {
109        println!("Removing legacy seller_token column...");
110        match sqlx::query("ALTER TABLE orders DROP COLUMN seller_token")
111            .execute(pool)
112            .await
113        {
114            Ok(_) => println!("Successfully removed seller_token column"),
115            Err(e) => {
116                println!("Warning: Could not remove seller_token column: {}", e);
117                // Continue execution - this is not critical
118            }
119        }
120    }
121
122    if buyer_token_exists == 0 && seller_token_exists == 0 {
123        println!("No legacy token columns found - database is up to date");
124    }
125
126    Ok(())
127}
128
129#[derive(Debug, Default, Clone, sqlx::FromRow)]
130pub struct User {
131    /// The user's ID is the identity pubkey
132    pub i0_pubkey: String,
133    pub mnemonic: String,
134    pub last_trade_index: Option<i64>,
135    pub created_at: i64,
136}
137
138impl User {
139    pub async fn new(mnemonic: String, pool: &SqlitePool) -> Result<Self> {
140        let mut user = User::default();
141        let account = NOSTR_REPLACEABLE_EVENT_KIND as u32;
142        let i0_keys =
143            Keys::from_mnemonic_advanced(&mnemonic, None, Some(account), Some(0), Some(0))?;
144        user.i0_pubkey = i0_keys.public_key().to_string();
145        user.created_at = chrono::Utc::now().timestamp();
146        user.mnemonic = mnemonic;
147        sqlx::query(
148            r#"
149                  INSERT INTO users (i0_pubkey, mnemonic, created_at)
150                  VALUES (?, ?, ?)
151                "#,
152        )
153        .bind(&user.i0_pubkey)
154        .bind(&user.mnemonic)
155        .bind(user.created_at)
156        .execute(pool)
157        .await?;
158
159        Ok(user)
160    }
161    // Chainable setters
162    pub fn set_mnemonic(&mut self, mnemonic: String) -> &mut Self {
163        self.mnemonic = mnemonic;
164        self
165    }
166
167    pub fn set_last_trade_index(&mut self, last_trade_index: i64) -> &mut Self {
168        self.last_trade_index = Some(last_trade_index);
169        self
170    }
171
172    // Applying changes to the database
173    pub async fn save(&self, pool: &SqlitePool) -> Result<()> {
174        sqlx::query(
175            r#"
176              UPDATE users 
177              SET mnemonic = ?, last_trade_index = ?
178              WHERE i0_pubkey = ?
179              "#,
180        )
181        .bind(&self.mnemonic)
182        .bind(self.last_trade_index)
183        .bind(&self.i0_pubkey)
184        .execute(pool)
185        .await?;
186
187        Ok(())
188    }
189
190    pub async fn get(pool: &SqlitePool) -> Result<User> {
191        let user = sqlx::query_as::<_, User>(
192            r#"
193            SELECT i0_pubkey, mnemonic, last_trade_index, created_at
194            FROM users
195            LIMIT 1
196            "#,
197        )
198        .fetch_one(pool)
199        .await?;
200
201        Ok(user)
202    }
203
204    pub async fn get_last_trade_index(pool: SqlitePool) -> Result<i64> {
205        let user = User::get(&pool).await?;
206        match user.last_trade_index {
207            Some(index) => Ok(index),
208            None => Ok(0),
209        }
210    }
211
212    pub async fn get_next_trade_index(pool: SqlitePool) -> Result<i64> {
213        let last_trade_index = User::get_last_trade_index(pool).await?;
214        Ok(last_trade_index + 1)
215    }
216
217    pub async fn get_identity_keys(pool: &SqlitePool) -> Result<Keys> {
218        let user = User::get(pool).await?;
219        let account = NOSTR_REPLACEABLE_EVENT_KIND as u32;
220        let keys =
221            Keys::from_mnemonic_advanced(&user.mnemonic, None, Some(account), Some(0), Some(0))?;
222
223        Ok(keys)
224    }
225
226    pub async fn get_next_trade_keys(pool: &SqlitePool) -> Result<(Keys, i64)> {
227        let trade_index = User::get_next_trade_index(pool.clone()).await?;
228        let trade_keys = User::get_trade_keys(pool, trade_index).await?;
229
230        Ok((trade_keys, trade_index))
231    }
232
233    pub async fn get_trade_keys(pool: &SqlitePool, index: i64) -> Result<Keys> {
234        if index < 0 {
235            return Err(anyhow::anyhow!("Trade index cannot be negative"));
236        }
237        let user = User::get(pool).await?;
238        let account = NOSTR_REPLACEABLE_EVENT_KIND as u32;
239        let keys = Keys::from_mnemonic_advanced(
240            &user.mnemonic,
241            None,
242            Some(account),
243            Some(0),
244            Some(index as u32),
245        )?;
246
247        Ok(keys)
248    }
249}
250
251#[derive(Debug, Default, Clone, sqlx::FromRow)]
252pub struct Order {
253    pub id: Option<String>,
254    pub kind: Option<String>,
255    pub status: Option<String>,
256    pub amount: i64,
257    pub fiat_code: String,
258    pub min_amount: Option<i64>,
259    pub max_amount: Option<i64>,
260    pub fiat_amount: i64,
261    pub payment_method: String,
262    pub premium: i64,
263    pub trade_keys: Option<String>,
264    pub counterparty_pubkey: Option<String>,
265    pub is_mine: Option<bool>,
266    pub buyer_invoice: Option<String>,
267    pub request_id: Option<i64>,
268    pub created_at: Option<i64>,
269    pub expires_at: Option<i64>,
270}
271
272impl Order {
273    pub async fn new(
274        pool: &SqlitePool,
275        order: SmallOrder,
276        trade_keys: &Keys,
277        request_id: Option<i64>,
278    ) -> Result<Self> {
279        let trade_keys_hex = trade_keys.secret_key().to_secret_hex();
280        let id = match order.id {
281            Some(id) => id.to_string(),
282            None => uuid::Uuid::new_v4().to_string(),
283        };
284        let order = Order {
285            id: Some(id),
286            kind: order.kind.as_ref().map(|k| k.to_string()),
287            status: order.status.as_ref().map(|s| s.to_string()),
288            amount: order.amount,
289            fiat_code: order.fiat_code,
290            min_amount: order.min_amount,
291            max_amount: order.max_amount,
292            fiat_amount: order.fiat_amount,
293            payment_method: order.payment_method,
294            premium: order.premium,
295            trade_keys: Some(trade_keys_hex),
296            counterparty_pubkey: None,
297            is_mine: Some(true),
298            buyer_invoice: None,
299            request_id,
300            created_at: Some(chrono::Utc::now().timestamp()),
301            expires_at: None,
302        };
303
304        // Try insert; if id already exists, perform an update instead
305        let insert_result = order.insert_db(pool).await;
306
307        if let Err(e) = insert_result {
308            // If the error is due to unique constraint (id already present), update instead
309            // SQLite uses error code 1555 (constraint failed) or 2067 (unique constraint failed)
310            let is_unique_violation = match e.as_database_error() {
311                Some(db_err) => {
312                    let code = db_err.code().map(|c| c.to_string()).unwrap_or_default();
313                    code == "1555" || code == "2067"
314                }
315                None => false,
316            };
317
318            if is_unique_violation {
319                order.update_db(pool).await?;
320            } else {
321                return Err(e.into());
322            }
323        }
324
325        Ok(order)
326    }
327
328    async fn insert_db(&self, pool: &SqlitePool) -> Result<(), sqlx::Error> {
329        sqlx::query(
330            r#"
331			      INSERT INTO orders (id, kind, status, amount, min_amount, max_amount,
332			      fiat_code, fiat_amount, payment_method, premium, trade_keys,
333			      counterparty_pubkey, is_mine, buyer_invoice, request_id, created_at, expires_at)
334			      VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
335			    "#,
336        )
337        .bind(&self.id)
338        .bind(&self.kind)
339        .bind(&self.status)
340        .bind(self.amount)
341        .bind(self.min_amount)
342        .bind(self.max_amount)
343        .bind(&self.fiat_code)
344        .bind(self.fiat_amount)
345        .bind(&self.payment_method)
346        .bind(self.premium)
347        .bind(&self.trade_keys)
348        .bind(&self.counterparty_pubkey)
349        .bind(self.is_mine)
350        .bind(&self.buyer_invoice)
351        .bind(self.request_id)
352        .bind(self.created_at)
353        .bind(self.expires_at)
354        .execute(pool)
355        .await?
356        .rows_affected();
357        Ok(())
358    }
359
360    async fn update_db(&self, pool: &SqlitePool) -> Result<(), sqlx::Error> {
361        sqlx::query(
362			r#"
363			  UPDATE orders 
364			  SET kind = ?, status = ?, amount = ?, min_amount = ?, max_amount = ?,
365			      fiat_code = ?, fiat_amount = ?, payment_method = ?, premium = ?, trade_keys = ?,
366			      counterparty_pubkey = ?, is_mine = ?, buyer_invoice = ?, request_id = ?, created_at = ?, expires_at = ?
367			  WHERE id = ?
368			"#,
369		)
370		.bind(&self.kind)
371		.bind(&self.status)
372		.bind(self.amount)
373		.bind(self.min_amount)
374		.bind(self.max_amount)
375		.bind(&self.fiat_code)
376		.bind(self.fiat_amount)
377		.bind(&self.payment_method)
378		.bind(self.premium)
379		.bind(&self.trade_keys)
380		.bind(&self.counterparty_pubkey)
381		.bind(self.is_mine)
382		.bind(&self.buyer_invoice)
383		.bind(self.request_id)
384		.bind(self.created_at)
385		.bind(self.expires_at)
386		.bind(&self.id)
387		.execute(pool)
388		.await?
389		.rows_affected();
390        Ok(())
391    }
392
393    // Setters encadenables
394    pub fn set_kind(&mut self, kind: String) -> &mut Self {
395        self.kind = Some(kind);
396        self
397    }
398
399    pub fn set_status(&mut self, status: String) -> &mut Self {
400        self.status = Some(status);
401        self
402    }
403
404    pub fn set_amount(&mut self, amount: i64) -> &mut Self {
405        self.amount = amount;
406        self
407    }
408
409    pub fn set_fiat_code(&mut self, fiat_code: String) -> &mut Self {
410        self.fiat_code = fiat_code;
411        self
412    }
413
414    pub fn set_min_amount(&mut self, min_amount: i64) -> &mut Self {
415        self.min_amount = Some(min_amount);
416        self
417    }
418
419    pub fn set_max_amount(&mut self, max_amount: i64) -> &mut Self {
420        self.max_amount = Some(max_amount);
421        self
422    }
423
424    pub fn set_fiat_amount(&mut self, fiat_amount: i64) -> &mut Self {
425        self.fiat_amount = fiat_amount;
426        self
427    }
428
429    pub fn set_payment_method(&mut self, payment_method: String) -> &mut Self {
430        self.payment_method = payment_method;
431        self
432    }
433
434    pub fn set_premium(&mut self, premium: i64) -> &mut Self {
435        self.premium = premium;
436        self
437    }
438
439    pub fn set_counterparty_pubkey(&mut self, counterparty_pubkey: String) -> &mut Self {
440        self.counterparty_pubkey = Some(counterparty_pubkey);
441        self
442    }
443
444    pub fn set_trade_keys(&mut self, trade_keys: String) -> &mut Self {
445        self.trade_keys = Some(trade_keys);
446        self
447    }
448
449    pub fn set_is_mine(&mut self, is_mine: bool) -> &mut Self {
450        self.is_mine = Some(is_mine);
451        self
452    }
453
454    // Applying changes to the database
455    pub async fn save(&self, pool: &SqlitePool) -> Result<()> {
456        // Validation if an identity document is present
457        if let Some(ref id) = self.id {
458            sqlx::query(
459                r#"
460              UPDATE orders 
461              SET kind = ?, status = ?, amount = ?, fiat_code = ?, min_amount = ?, max_amount = ?, 
462                  fiat_amount = ?, payment_method = ?, premium = ?, trade_keys = ?, counterparty_pubkey = ?,
463                  is_mine = ?, buyer_invoice = ?, expires_at = ?
464              WHERE id = ?
465              "#,
466            )
467            .bind(&self.kind)
468            .bind(&self.status)
469            .bind(self.amount)
470            .bind(&self.fiat_code)
471            .bind(self.min_amount)
472            .bind(self.max_amount)
473            .bind(self.fiat_amount)
474            .bind(&self.payment_method)
475            .bind(self.premium)
476            .bind(&self.trade_keys)
477            .bind(&self.counterparty_pubkey)
478            .bind(self.is_mine)
479            .bind(&self.buyer_invoice)
480            .bind(self.expires_at)
481            .bind(id)
482            .execute(pool)
483            .await?;
484
485            println!("Order with id {} updated in the database.", id);
486        } else {
487            return Err(anyhow::anyhow!("Order must have an ID to be updated."));
488        }
489
490        Ok(())
491    }
492
493    pub async fn save_new_id(
494        pool: &SqlitePool,
495        id: String,
496        new_id: String,
497    ) -> anyhow::Result<bool> {
498        let rows_affected = sqlx::query(
499            r#"
500          UPDATE orders
501          SET id = ?
502          WHERE id = ?
503        "#,
504        )
505        .bind(&new_id)
506        .bind(&id)
507        .execute(pool)
508        .await?
509        .rows_affected();
510
511        Ok(rows_affected > 0)
512    }
513
514    pub async fn get_by_id(pool: &SqlitePool, id: &str) -> Result<Order> {
515        let order = sqlx::query_as::<_, Order>(
516            r#"
517            SELECT * FROM orders WHERE id = ?
518            LIMIT 1
519            "#,
520        )
521        .bind(id)
522        .fetch_one(pool)
523        .await?;
524
525        if order.id.is_none() {
526            return Err(anyhow::anyhow!("Order not found"));
527        }
528
529        Ok(order)
530    }
531
532    pub async fn get_all_trade_keys(pool: &SqlitePool) -> Result<Vec<String>> {
533        let trade_keys: Vec<String> = sqlx::query_scalar::<_, Option<String>>(
534            "SELECT DISTINCT trade_keys FROM orders WHERE trade_keys IS NOT NULL",
535        )
536        .fetch_all(pool)
537        .await?
538        .into_iter()
539        .flatten()
540        .collect();
541
542        Ok(trade_keys)
543    }
544
545    pub async fn delete_by_id(pool: &SqlitePool, id: &str) -> Result<bool> {
546        let rows_affected = sqlx::query(
547            r#"
548          DELETE FROM orders
549          WHERE id = ?
550        "#,
551        )
552        .bind(id)
553        .execute(pool)
554        .await?
555        .rows_affected();
556
557        Ok(rows_affected > 0)
558    }
559}