debot_db/
transaction_log.rs

1// transaction_log.rs
2
3use bson::doc;
4use bson::Bson;
5use bson::Document;
6use debot_utils::get_local_time;
7use debot_utils::HasId;
8use mongodb::Collection;
9use mongodb::{
10    options::{ClientOptions, Tls, TlsOptions},
11    Database,
12};
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15use shared_mongodb::{database, ClientHolder};
16use std::collections::HashMap;
17use std::error;
18use std::fs::File;
19use std::io::{Read, Write};
20use std::path::Path;
21use std::sync::Arc;
22use std::time::SystemTime;
23use tokio::sync::Mutex;
24
25use crate::delete_item_all;
26use crate::SearchMode;
27use crate::TradingStrategy;
28use crate::{
29    create_unique_index, insert_item, search_item, search_items, update_item, Counter, CounterType,
30    Entity,
31};
32
33async fn get_last_id<T: Default + Entity + HasId>(db: &Database) -> u32 {
34    let item = T::default();
35    match search_items(
36        db,
37        &item,
38        crate::SearchMode::Descending,
39        Some(1),
40        None,
41        Some("id"),
42    )
43    .await
44    {
45        Ok(mut items) => items.pop().and_then(|item| item.id()).unwrap_or(0),
46        Err(e) => {
47            log::info!("get_last_id: {:?}", e);
48            0
49        }
50    }
51}
52
53#[derive(Serialize, Deserialize, Clone, Debug)]
54pub enum SampleTerm {
55    TradingTerm,
56    ShortTerm,
57    LongTerm,
58}
59
60impl SampleTerm {
61    pub fn to_numeric(&self) -> Decimal {
62        match self {
63            SampleTerm::TradingTerm => Decimal::new(1, 0),
64            SampleTerm::ShortTerm => Decimal::new(2, 0),
65            SampleTerm::LongTerm => Decimal::new(3, 0),
66        }
67    }
68}
69
70#[derive(Serialize, Deserialize, Clone, Debug)]
71pub struct FundConfig {
72    pub token: String,
73    pub trading_strategy: TradingStrategy,
74    pub balance_per_strategy: Decimal,
75    pub risk_reward: Decimal,
76    pub take_profit_ratio: Option<Decimal>,
77    pub atr_spread: Option<Decimal>,
78    pub atr_term: SampleTerm,
79    pub open_minutes: i64,
80    pub flash_crash_atr_multiplier: Option<Decimal>,
81    pub order_size_multiplier: Decimal,
82    pub flash_crash_min_percentage: Option<Decimal>,
83    pub breakout_atr_multiplier: Decimal,
84}
85
86#[derive(Serialize, Deserialize, Clone, Debug)]
87pub struct AppState {
88    pub id: u32,
89    pub last_execution_time: Option<SystemTime>,
90    pub last_equity: Option<Decimal>,
91    pub ave_dd: Option<Decimal>,
92    pub max_dd: Option<Decimal>,
93    pub cumulative_return: Decimal,
94    pub cumulative_dd: Decimal,
95    pub score: Option<Decimal>,
96    pub score_2: Option<Decimal>,
97    pub score_3: Option<Decimal>,
98    pub curcuit_break: bool,
99    pub error_time: Vec<String>,
100    pub max_invested_amount: Decimal,
101    pub fund_configs: Option<Vec<FundConfig>>,
102}
103
104impl Default for AppState {
105    fn default() -> Self {
106        Self {
107            id: 1,
108            last_execution_time: None,
109            last_equity: None,
110            ave_dd: None,
111            max_dd: None,
112            cumulative_return: Decimal::ZERO,
113            cumulative_dd: Decimal::ZERO,
114            score: None,
115            score_2: None,
116            score_3: None,
117            curcuit_break: false,
118            error_time: vec![],
119            max_invested_amount: Decimal::ZERO,
120            fund_configs: Some(vec![]),
121        }
122    }
123}
124
125#[derive(Serialize, Deserialize, Clone, Debug, Default)]
126pub struct PnlLog {
127    pub id: Option<u32>,
128    pub date: String,
129    pub pnl: Decimal,
130}
131
132impl HasId for PnlLog {
133    fn id(&self) -> Option<u32> {
134        self.id
135    }
136}
137
138#[derive(Serialize, Deserialize, Debug, Clone, Default)]
139pub struct PricePoint {
140    pub timestamp: i64,
141    pub timestamp_str: String,
142    pub price: Decimal,
143    pub volume: Option<Decimal>,
144    pub num_trades: Option<u64>,
145    pub funding_rate: Option<Decimal>,
146    pub open_interest: Option<Decimal>,
147    pub oracle_price: Option<Decimal>,
148    pub debug: DebugLog,
149}
150
151impl PricePoint {
152    pub fn new(
153        price: Decimal,
154        timestamp: Option<i64>,
155        volume: Option<Decimal>,
156        num_trades: Option<u64>,
157        funding_rate: Option<Decimal>,
158        open_interest: Option<Decimal>,
159        oracle_price: Option<Decimal>,
160        debug: DebugLog,
161    ) -> Self {
162        let (local_timestamp, timestamp_str) = get_local_time();
163        let timestamp = timestamp.unwrap_or(local_timestamp);
164        Self {
165            timestamp,
166            timestamp_str,
167            price,
168            volume,
169            num_trades,
170            funding_rate,
171            open_interest,
172            oracle_price,
173            debug,
174        }
175    }
176}
177
178#[derive(Serialize, Deserialize, Clone, Debug, Default)]
179pub struct PriceLog {
180    pub id: Option<u32>,
181    pub name: String,
182    pub token_name: String,
183    pub price_point: PricePoint,
184}
185
186impl HasId for PriceLog {
187    fn id(&self) -> Option<u32> {
188        self.id
189    }
190}
191
192#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
193pub enum CandlePattern {
194    #[default]
195    None,
196    Hammer,
197    InvertedHammer,
198    BullishEngulfing,
199    BearishEngulfing,
200    Doji,
201    Marubozu,
202    MorningStar,
203    EveningStar,
204    ThreeWhiteSoldiers,
205    ThreeBlackCrows,
206    PiercingPattern,
207    DarkCloudCover,
208    Harami,
209    HaramiCross,
210    SpinningTop,
211}
212
213impl CandlePattern {
214    pub fn to_one_hot(&self) -> [Decimal; 16] {
215        let mut one_hot = [Decimal::ZERO; 16];
216
217        match self {
218            CandlePattern::None => one_hot[0] = Decimal::ONE,
219            CandlePattern::Hammer => one_hot[1] = Decimal::ONE,
220            CandlePattern::InvertedHammer => one_hot[2] = Decimal::ONE,
221            CandlePattern::BullishEngulfing => one_hot[3] = Decimal::ONE,
222            CandlePattern::BearishEngulfing => one_hot[4] = Decimal::ONE,
223            CandlePattern::Doji => one_hot[5] = Decimal::ONE,
224            CandlePattern::Marubozu => one_hot[6] = Decimal::ONE,
225            CandlePattern::MorningStar => one_hot[7] = Decimal::ONE,
226            CandlePattern::EveningStar => one_hot[8] = Decimal::ONE,
227            CandlePattern::ThreeWhiteSoldiers => one_hot[9] = Decimal::ONE,
228            CandlePattern::ThreeBlackCrows => one_hot[10] = Decimal::ONE,
229            CandlePattern::PiercingPattern => one_hot[11] = Decimal::ONE,
230            CandlePattern::DarkCloudCover => one_hot[12] = Decimal::ONE,
231            CandlePattern::Harami => one_hot[13] = Decimal::ONE,
232            CandlePattern::HaramiCross => one_hot[14] = Decimal::ONE,
233            CandlePattern::SpinningTop => one_hot[15] = Decimal::ONE,
234        }
235
236        one_hot
237    }
238}
239
240#[derive(Serialize, Deserialize, Clone, Debug, Default)]
241pub struct DebugLog {
242    pub input_1: Decimal,
243    pub input_2: Decimal,
244    pub input_3: Decimal,
245    pub input_4: Decimal,
246    pub input_5: Decimal,
247    pub input_6: Decimal,
248    pub input_7: Decimal,
249    pub input_8: Decimal,
250    pub input_9: Decimal,
251    pub input_10: Decimal,
252    pub input_11: Decimal,
253    pub input_12: Decimal,
254    pub input_13: Decimal,
255    pub input_14: Decimal,
256    pub input_15: Decimal,
257    pub input_16: Decimal,
258    pub input_17: Decimal,
259    pub input_18: Decimal,
260    pub input_19: Decimal,
261    pub input_20: Decimal,
262    pub input_21: Decimal,
263    pub input_22: Decimal,
264    pub input_23: Decimal,
265    pub input_24: Decimal,
266    pub input_25: Decimal,
267    pub input_26: Decimal,
268    pub input_27: Decimal,
269    pub input_28: Decimal,
270    pub input_29: Decimal,
271    pub input_30: CandlePattern,
272    pub input_31: CandlePattern,
273    pub input_32: CandlePattern,
274    pub input_33: CandlePattern,
275    pub input_34: CandlePattern,
276    pub input_35: CandlePattern,
277    pub input_36: CandlePattern,
278    pub input_37: CandlePattern,
279    pub input_38: CandlePattern,
280    pub input_39: CandlePattern,
281    pub output_1: Decimal,
282    pub output_2: Decimal,
283    pub output_3: Option<Decimal>,
284    pub output_4: Option<Decimal>,
285    pub output_5: Option<Decimal>,
286}
287
288#[derive(Serialize, Deserialize, Clone, Debug, Default)]
289pub struct PositionLog {
290    pub id: Option<u32>,
291    pub fund_name: String,
292    pub order_id: String,
293    pub ordered_price: Decimal,
294    pub state: String,
295    pub token_name: String,
296    pub open_time_str: String,
297    pub open_timestamp: i64,
298    pub close_time_str: String,
299    pub average_open_price: Decimal,
300    pub position_type: String,
301    pub close_price: Decimal,
302    pub asset_in_usd: Decimal,
303    pub pnl: Decimal,
304    pub fee: Decimal,
305    pub debug: DebugLog,
306}
307
308#[derive(Serialize, Deserialize)]
309pub struct SerializableModel {
310    pub model: Vec<u8>,
311}
312
313impl HasId for PositionLog {
314    fn id(&self) -> Option<u32> {
315        self.id
316    }
317}
318
319pub struct TransactionLog {
320    counter: Counter,
321    db_r_name: String,
322    db_w_name: String,
323    client_holder: Arc<Mutex<ClientHolder>>,
324}
325
326impl TransactionLog {
327    pub async fn new(
328        max_position_counter: Option<u32>,
329        max_price_counter: Option<u32>,
330        max_pnl_counter: Option<u32>,
331        mongodb_uri: &str,
332        db_r_name: &str,
333        db_w_name: &str,
334        back_test: bool,
335    ) -> Self {
336        // Set up the DB client holder
337        let mut client_options = match ClientOptions::parse(mongodb_uri).await {
338            Ok(client_options) => client_options,
339            Err(e) => {
340                panic!("{:?}", e);
341            }
342        };
343        let tls_options = TlsOptions::builder().build();
344        client_options.tls = Some(Tls::Enabled(tls_options));
345        let client_holder = Arc::new(Mutex::new(ClientHolder::new(client_options)));
346
347        // Get database instances for read and write
348        let db_w = shared_mongodb::database::get(&client_holder, &db_w_name)
349            .await
350            .unwrap();
351        let db_r = shared_mongodb::database::get(&client_holder, &db_r_name)
352            .await
353            .unwrap();
354
355        // Ensure indexes exist in both read and write databases
356        create_unique_index(&db_w)
357            .await
358            .expect("Error creating unique index in db_w");
359        create_unique_index(&db_r)
360            .await
361            .expect("Error creating unique index in db_r");
362
363        if back_test {
364            if let Err(e) = Self::delete_all_positions(&db_w).await {
365                panic!("delete_all_positions failed: {:?}", e);
366            }
367            if let Err(e) = Self::delete_app_state(&db_w).await {
368                panic!("delete_app_state failed: {:?}", e);
369            }
370        }
371
372        let last_position_counter =
373            TransactionLog::get_last_transaction_id(&db_w, CounterType::Position).await;
374        let last_price_counter =
375            TransactionLog::get_last_transaction_id(&db_w, CounterType::Price).await;
376        let last_pnl_counter =
377            TransactionLog::get_last_transaction_id(&db_w, CounterType::Pnl).await;
378
379        let counter = Counter::new(
380            max_position_counter,
381            max_price_counter,
382            max_pnl_counter,
383            last_position_counter,
384            last_price_counter,
385            last_pnl_counter,
386        );
387
388        log::warn!(
389            "position = {}/{:?}, price = {}/{:?}, pnl = {}/{:?}",
390            last_position_counter,
391            max_position_counter,
392            last_price_counter,
393            max_price_counter,
394            last_pnl_counter,
395            max_pnl_counter,
396        );
397
398        TransactionLog {
399            counter,
400            db_r_name: db_r_name.to_owned(),
401            db_w_name: db_w_name.to_owned(),
402            client_holder,
403        }
404    }
405
406    pub fn increment_counter(&self, counter_type: CounterType) -> u32 {
407        self.counter.increment(counter_type)
408    }
409
410    pub async fn get_last_transaction_id(db: &Database, counter_type: CounterType) -> u32 {
411        match counter_type {
412            CounterType::Position => get_last_id::<PositionLog>(db).await,
413            CounterType::Price => get_last_id::<PriceLog>(db).await,
414            CounterType::Pnl => get_last_id::<PnlLog>(db).await,
415        }
416    }
417
418    pub async fn get_w_db(&self) -> Option<Database> {
419        self.get_db(false).await
420    }
421
422    pub async fn get_r_db(&self) -> Option<Database> {
423        self.get_db(true).await
424    }
425
426    async fn get_db(&self, read: bool) -> Option<Database> {
427        let db_name = if read {
428            &self.db_r_name
429        } else {
430            &self.db_w_name
431        };
432        let db = match database::get(&self.client_holder, db_name).await {
433            Ok(db) => Some(db),
434            Err(e) => {
435                log::error!("get_db: {:?}", e);
436                None
437            }
438        };
439        db
440    }
441
442    pub async fn update_transaction(
443        db: &Database,
444        item: &PositionLog,
445    ) -> Result<(), Box<dyn error::Error>> {
446        update_item(db, item).await?;
447        Ok(())
448    }
449
450    pub async fn update_price(db: &Database, item: PriceLog) -> Result<(), Box<dyn error::Error>> {
451        update_item(db, &item).await?;
452        Ok(())
453    }
454
455    pub async fn copy_price(db_r: &Database, db_w: &Database, limit: Option<u32>) {
456        let item = PriceLog::default();
457        let items = {
458            match search_items(db_r, &item, SearchMode::Ascending, limit, None, Some("id")).await {
459                Ok(items) => items,
460                Err(e) => {
461                    log::error!("get price: {:?}", e);
462                    return;
463                }
464            }
465        };
466        log::debug!("get prices: num = {}", items.len());
467
468        for item in &items {
469            match insert_item(db_w, item).await {
470                Ok(_) => {}
471                Err(e) => {
472                    log::error!("write price: {:?}", e);
473                    return;
474                }
475            }
476        }
477    }
478
479    pub async fn copy_position(db_r: &Database, db_w: &Database, limit: Option<u32>) {
480        let item = PositionLog::default();
481        let items = {
482            match search_items(db_r, &item, SearchMode::Ascending, limit, None, Some("id")).await {
483                Ok(items) => items,
484                Err(e) => {
485                    log::error!("get position: {:?}", e);
486                    return;
487                }
488            }
489        };
490        log::debug!("get position: num = {}", items.len());
491
492        for item in &items {
493            match insert_item(db_w, item).await {
494                Ok(_) => {}
495                Err(e) => {
496                    log::error!("write position: {:?}", e);
497                    return;
498                }
499            }
500        }
501    }
502
503    pub async fn get_price_market_data(
504        db: &Database,
505        limit: Option<u32>,
506        id: Option<u32>,
507        is_ascend: bool,
508    ) -> HashMap<String, HashMap<String, Vec<PricePoint>>> {
509        let search_mode = if is_ascend {
510            SearchMode::Ascending
511        } else {
512            SearchMode::Descending
513        };
514        let sort_key = Some("price_point.timestamp");
515        let item = PriceLog::default();
516
517        let items = match id {
518            Some(id) => search_item(db, &item, Some(id), sort_key)
519                .await
520                .map(|item| vec![item]),
521            None => search_items(db, &item, search_mode, limit, None, sort_key).await,
522        };
523
524        let Ok(mut items) = items else {
525            log::warn!("get_price_market_data: search failed");
526            return HashMap::new();
527        };
528
529        items.sort_by_key(|p| p.price_point.timestamp);
530
531        let mut result = HashMap::new();
532        for price_log in items {
533            result
534                .entry(price_log.name)
535                .or_insert_with(HashMap::new)
536                .entry(price_log.token_name)
537                .or_insert_with(Vec::new)
538                .push(price_log.price_point);
539        }
540
541        result
542    }
543
544    pub async fn get_all_positions(
545        db: &Database,
546        limit: Option<u32>,
547        id: Option<u32>,
548        is_ascend: bool,
549    ) -> Vec<PositionLog> {
550        let search_mode = if is_ascend {
551            SearchMode::Ascending
552        } else {
553            SearchMode::Descending
554        };
555        let sort_key = Some("open_timestamp");
556        let item = PositionLog::default();
557
558        let items = if let Some(id) = id {
559            match search_item(db, &item, Some(id), sort_key).await {
560                Ok(position) => vec![position],
561                Err(e) => {
562                    log::warn!("get_all_positions: {:?}", e);
563                    vec![]
564                }
565            }
566        } else {
567            match search_items(db, &item, search_mode, limit, None, sort_key).await {
568                Ok(positions) => positions,
569                Err(e) => {
570                    log::warn!("get_all_positions: {:?}", e);
571                    vec![]
572                }
573            }
574        };
575
576        items
577    }
578
579    async fn delete_all_positions(db: &Database) -> Result<(), Box<dyn error::Error>> {
580        let item = PositionLog::default();
581        delete_item_all(db, &item).await
582    }
583
584    pub async fn insert_pnl(db: &Database, item: PnlLog) -> Result<(), Box<dyn error::Error>> {
585        insert_item(db, &item).await?;
586        Ok(())
587    }
588
589    pub async fn get_app_state(db: &Database) -> AppState {
590        let item = AppState::default();
591        match search_item(db, &item, Some(1), Some("id")).await {
592            Ok(item) => item,
593            Err(e) => {
594                log::warn!("get_app_state: {:?}", e);
595                item
596            }
597        }
598    }
599
600    async fn delete_app_state(db: &Database) -> Result<(), Box<dyn error::Error>> {
601        let item = AppState::default();
602        delete_item_all(db, &item).await
603    }
604
605    pub async fn update_app_state(
606        db: &Database,
607        last_execution_time: Option<SystemTime>,
608        last_equity: Option<Decimal>,
609        ave_dd: Option<Decimal>,
610        max_dd: Option<Decimal>,
611        cumulative_return: Option<Decimal>,
612        cumulative_dd: Option<Decimal>,
613        score: Option<Decimal>,
614        score_2: Option<Decimal>,
615        score_3: Option<Decimal>,
616        curcuit_break: bool,
617        error_time: Option<String>,
618        max_invested_amount: Option<Decimal>,
619        fund_configs: Option<Vec<FundConfig>>,
620    ) -> Result<(), Box<dyn error::Error>> {
621        let item = AppState::default();
622        let mut item = match search_item(db, &item, Some(1), Some("id")).await {
623            Ok(prev_item) => prev_item,
624            Err(_) => item,
625        };
626
627        if last_execution_time.is_some() {
628            item.last_execution_time = last_execution_time;
629        }
630
631        if let Some(last_equity) = last_equity {
632            item.last_equity = Some(last_equity.round());
633        }
634
635        if let Some(ave_dd) = ave_dd {
636            item.ave_dd = Some(ave_dd.round());
637        }
638
639        if let Some(max_dd_val) = max_dd {
640            if item
641                .max_dd
642                .map_or(true, |item_max_dd| max_dd_val > item_max_dd)
643            {
644                item.max_dd = Some(max_dd_val.round());
645            }
646        }
647
648        if let Some(cumulative_return) = cumulative_return {
649            item.cumulative_return += cumulative_return.round();
650        }
651
652        if let Some(cumulative_dd) = cumulative_dd {
653            item.cumulative_dd += cumulative_dd.round();
654        }
655
656        if score.is_some() {
657            item.score = score;
658        }
659
660        if score_2.is_some() {
661            item.score_2 = score_2;
662        }
663
664        if score_3.is_some() {
665            item.score_3 = score_3;
666        }
667
668        item.curcuit_break = curcuit_break;
669
670        if let Some(error_time) = error_time {
671            item.error_time.push(error_time);
672        }
673
674        if let Some(max_invested_amount) = max_invested_amount {
675            item.max_invested_amount = max_invested_amount.round();
676        }
677
678        if let Some(fund_configs) = fund_configs {
679            item.fund_configs = Some(fund_configs);
680        }
681
682        update_item(db, &item).await?;
683        Ok(())
684    }
685
686    pub fn db_w_name(&self) -> &str {
687        &self.db_w_name
688    }
689}
690
691#[derive(Clone)]
692pub struct ModelParams {
693    db_name: String,
694    client_holder: Arc<Mutex<ClientHolder>>,
695    collection_name: String,
696    save_to_db: bool,
697    file_path: Option<String>,
698}
699
700impl ModelParams {
701    pub async fn new(
702        mongodb_uri: &str,
703        db_name: &str,
704        save_to_db: bool,
705        file_path: Option<String>,
706    ) -> Self {
707        // Set up the DB client holder
708        let mut client_options = match ClientOptions::parse(mongodb_uri).await {
709            Ok(client_options) => client_options,
710            Err(e) => {
711                panic!("{:?}", e);
712            }
713        };
714        let tls_options = TlsOptions::builder().build();
715        client_options.tls = Some(Tls::Enabled(tls_options));
716        let client_holder = Arc::new(Mutex::new(ClientHolder::new(client_options)));
717
718        ModelParams {
719            db_name: db_name.to_owned(),
720            client_holder,
721            collection_name: "model_params".to_owned(),
722            save_to_db,
723            file_path,
724        }
725    }
726
727    async fn get_db(&self) -> Option<Database> {
728        let db = match database::get(&self.client_holder, &self.db_name).await {
729            Ok(db) => Some(db),
730            Err(e) => {
731                log::error!("get_db: {:?}", e);
732                None
733            }
734        };
735        db
736    }
737
738    pub async fn save_model(
739        &self,
740        key: &str,
741        model: &SerializableModel,
742    ) -> Result<(), Box<dyn std::error::Error>> {
743        if self.save_to_db {
744            self.save_model_to_db(key, model).await
745        } else {
746            self.save_model_to_file(key, model).await
747        }
748    }
749
750    pub async fn load_model(
751        &self,
752        key: &str,
753    ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
754        if self.save_to_db {
755            self.load_model_from_db(key).await
756        } else {
757            self.load_model_from_file(key).await
758        }
759    }
760
761    async fn save_model_to_db(
762        &self,
763        key: &str,
764        model: &SerializableModel,
765    ) -> Result<(), Box<dyn std::error::Error>> {
766        let db = self.get_db().await.ok_or("no db")?;
767        let collection: Collection<Document> = db.collection(&self.collection_name);
768        let serialized_model = bincode::serialize(model)?;
769
770        let document = doc! {
771            "key": key,
772            "model": Bson::Binary(mongodb::bson::Binary {
773                subtype: mongodb::bson::spec::BinarySubtype::Generic,
774                bytes: serialized_model
775            })
776        };
777
778        collection
779            .update_one(
780                doc! { "key": key },
781                doc! { "$set": document },
782                mongodb::options::UpdateOptions::builder()
783                    .upsert(true)
784                    .build(),
785            )
786            .await?;
787        Ok(())
788    }
789
790    async fn load_model_from_db(
791        &self,
792        key: &str,
793    ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
794        let db = self.get_db().await.ok_or("no db")?;
795        let collection: Collection<Document> = db.collection(&self.collection_name);
796
797        let filter = doc! { "key": key };
798        let document = collection
799            .find_one(filter, None)
800            .await?
801            .ok_or("No model found in the collection")?;
802
803        if let Some(Bson::Binary(model_bytes)) = document.get("model") {
804            let model: SerializableModel = bincode::deserialize(&model_bytes.bytes)?;
805            Ok(model)
806        } else {
807            Err("Invalid data format".into())
808        }
809    }
810
811    async fn save_model_to_file(
812        &self,
813        key: &str,
814        model: &SerializableModel,
815    ) -> Result<(), Box<dyn std::error::Error>> {
816        let serialized_model = bincode::serialize(model)?;
817        let file_name = format!("{}.bin", key);
818
819        let file_path = if let Some(ref dir) = self.file_path {
820            Path::new(dir).join(file_name)
821        } else {
822            Path::new(&file_name).to_path_buf()
823        };
824
825        let mut file = File::create(&file_path)?;
826        file.write_all(&serialized_model)?;
827        Ok(())
828    }
829
830    async fn load_model_from_file(
831        &self,
832        key: &str,
833    ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
834        let file_name = format!("{}.bin", key);
835
836        let file_path = if let Some(ref dir) = self.file_path {
837            Path::new(dir).join(file_name)
838        } else {
839            Path::new(&file_name).to_path_buf()
840        };
841
842        let mut file = File::open(&file_path).map_err(|e| {
843            log::error!("Failed to open file: {:?}: {}", file_path, e);
844            e
845        })?;
846        let mut buffer = Vec::new();
847        file.read_to_end(&mut buffer)?;
848        let model: SerializableModel = bincode::deserialize(&buffer)?;
849        Ok(model)
850    }
851}