debot_db/
transaction_log.rs

1// transaction_log.rs
2
3use bson::doc;
4use bson::Bson;
5use bson::Document;
6use debot_utils::get_local_time;
7use debot_utils::HasId;
8use mongodb::Collection;
9use mongodb::{
10    options::{ClientOptions, Tls, TlsOptions},
11    Database,
12};
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15use shared_mongodb::{database, ClientHolder};
16use std::collections::HashMap;
17use std::error;
18use std::fs::File;
19use std::io::{Read, Write};
20use std::path::Path;
21use std::sync::Arc;
22use std::time::SystemTime;
23use tokio::sync::Mutex;
24
25use crate::delete_item_all;
26use crate::SearchMode;
27use crate::TradingStrategy;
28use crate::{
29    create_unique_index, insert_item, search_item, search_items, update_item, Counter, CounterType,
30    Entity,
31};
32
/// Serde fallback for `DebugLog::input_40`, referenced by that field's
/// `#[serde(default = "default_input_40")]` attribute. Always returns `Decimal::ONE`.
fn default_input_40() -> Decimal {
    Decimal::ONE
}
36
37async fn get_last_id<T: Default + Entity + HasId>(db: &Database) -> u32 {
38    let item = T::default();
39    match search_items(
40        db,
41        &item,
42        crate::SearchMode::Descending,
43        Some(1),
44        None,
45        Some("id"),
46    )
47    .await
48    {
49        Ok(mut items) => items.pop().and_then(|item| item.id()).unwrap_or(0),
50        Err(e) => {
51            log::info!("get_last_id: {:?}", e);
52            0
53        }
54    }
55}
56
/// Term selector with three variants. `SampleTerm::to_numeric` maps them to
/// 1, 2 and 3 respectively; it is the type of `FundConfig::atr_term`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum SampleTerm {
    TradingTerm,
    ShortTerm,
    LongTerm,
}
63
64impl SampleTerm {
65    pub fn to_numeric(&self) -> Decimal {
66        match self {
67            SampleTerm::TradingTerm => Decimal::new(1, 0),
68            SampleTerm::ShortTerm => Decimal::new(2, 0),
69            SampleTerm::LongTerm => Decimal::new(3, 0),
70        }
71    }
72}
73
/// Configuration record for a single fund. A list of these is stored in
/// `AppState::fund_configs` and replaced wholesale by
/// `TransactionLog::update_app_state`.
///
/// Field meanings are not established in this file; the names are the only
/// guide (e.g. the `_sec` fields are presumably durations in seconds — confirm
/// against the code that produces them).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct FundConfig {
    pub token: String,
    pub trading_strategy: TradingStrategy,
    pub balance_per_strategy: Decimal,
    pub risk_reward: Decimal,
    pub take_profit_ratio: Option<Decimal>,
    pub atr_spread: Decimal,
    pub atr_term: SampleTerm,
    pub entry_timeout_sec: i64,
    pub max_holding_sec: i64,
    pub order_size_multiplier: Decimal,
    pub tick_spread: i64,
    pub bias_ticks: i64,
}
89
/// Application state record persisted through `TransactionLog`.
///
/// `TransactionLog::get_app_state` and `TransactionLog::update_app_state`
/// always look the record up with id `1`, which is also the default `id`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct AppState {
    pub id: u32,
    pub last_execution_time: Option<SystemTime>,
    pub last_equity: Option<Decimal>,
    pub ave_dd: Option<Decimal>,
    // Only ever raised by `update_app_state`, never lowered.
    pub max_dd: Option<Decimal>,
    // Running totals: `update_app_state` adds to these instead of overwriting.
    pub cumulative_return: Decimal,
    pub cumulative_dd: Decimal,
    pub score: Option<Decimal>,
    pub score_2: Option<Decimal>,
    pub score_3: Option<Decimal>,
    // Spelling (sic) is part of the public and serialized field name; renaming
    // it would change the stored format.
    pub curcuit_break: bool,
    // Grows by one entry each time `update_app_state` receives an `error_time`.
    pub error_time: Vec<String>,
    pub max_invested_amount: Decimal,
    pub fund_configs: Option<Vec<FundConfig>>,
}
107
108impl Default for AppState {
109    fn default() -> Self {
110        Self {
111            id: 1,
112            last_execution_time: None,
113            last_equity: None,
114            ave_dd: None,
115            max_dd: None,
116            cumulative_return: Decimal::ZERO,
117            cumulative_dd: Decimal::ZERO,
118            score: None,
119            score_2: None,
120            score_3: None,
121            curcuit_break: false,
122            error_time: vec![],
123            max_invested_amount: Decimal::ZERO,
124            fund_configs: Some(vec![]),
125        }
126    }
127}
128
/// One PnL log entry: an optional id, a date string and the PnL value.
/// Inserted through `TransactionLog::insert_pnl`.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PnlLog {
    pub id: Option<u32>,
    pub date: String,
    pub pnl: Decimal,
}
135
impl HasId for PnlLog {
    /// Returns the entry's `id` field (`None` when no id is set).
    fn id(&self) -> Option<u32> {
        self.id
    }
}
141
/// A single price observation with optional market metadata and an optional
/// `DebugLog` attachment. Embedded in `PriceLog::price_point`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct PricePoint {
    // Used as the sort key in `TransactionLog::get_price_market_data`.
    pub timestamp: i64,
    // Filled from `get_local_time()` by `PricePoint::new`.
    pub timestamp_str: String,
    pub price: Decimal,
    pub volume: Option<Decimal>,
    pub num_trades: Option<u64>,
    pub funding_rate: Option<Decimal>,
    pub open_interest: Option<Decimal>,
    pub oracle_price: Option<Decimal>,
    pub debug: Option<DebugLog>,
}
154
155impl PricePoint {
156    pub fn new(
157        price: Decimal,
158        timestamp: Option<i64>,
159        volume: Option<Decimal>,
160        num_trades: Option<u64>,
161        funding_rate: Option<Decimal>,
162        open_interest: Option<Decimal>,
163        oracle_price: Option<Decimal>,
164        debug: Option<DebugLog>,
165    ) -> Self {
166        let (local_timestamp, timestamp_str) = get_local_time();
167        let timestamp = timestamp.unwrap_or(local_timestamp);
168        Self {
169            timestamp,
170            timestamp_str,
171            price,
172            volume,
173            num_trades,
174            funding_rate,
175            open_interest,
176            oracle_price,
177            debug,
178        }
179    }
180}
181
/// Stored price record: an optional id, two name keys and the `PricePoint`.
/// `TransactionLog::get_price_market_data` groups records by `name` and then
/// by `token_name`.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PriceLog {
    pub id: Option<u32>,
    pub name: String,
    pub token_name: String,
    pub price_point: PricePoint,
}
189
impl HasId for PriceLog {
    /// Returns the record's `id` field (`None` when no id is set).
    fn id(&self) -> Option<u32> {
        self.id
    }
}
195
/// Candlestick pattern label with sixteen variants; `None` is the default.
/// `CandlePattern::to_one_hot` assigns each variant one slot of a 16-element
/// array, in the same order as the variants are declared here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum CandlePattern {
    #[default]
    None,
    Hammer,
    InvertedHammer,
    BullishEngulfing,
    BearishEngulfing,
    Doji,
    Marubozu,
    MorningStar,
    EveningStar,
    ThreeWhiteSoldiers,
    ThreeBlackCrows,
    PiercingPattern,
    DarkCloudCover,
    Harami,
    HaramiCross,
    SpinningTop,
}
216
217impl CandlePattern {
218    pub fn to_one_hot(&self) -> [Decimal; 16] {
219        let mut one_hot = [Decimal::ZERO; 16];
220
221        match self {
222            CandlePattern::None => one_hot[0] = Decimal::ONE,
223            CandlePattern::Hammer => one_hot[1] = Decimal::ONE,
224            CandlePattern::InvertedHammer => one_hot[2] = Decimal::ONE,
225            CandlePattern::BullishEngulfing => one_hot[3] = Decimal::ONE,
226            CandlePattern::BearishEngulfing => one_hot[4] = Decimal::ONE,
227            CandlePattern::Doji => one_hot[5] = Decimal::ONE,
228            CandlePattern::Marubozu => one_hot[6] = Decimal::ONE,
229            CandlePattern::MorningStar => one_hot[7] = Decimal::ONE,
230            CandlePattern::EveningStar => one_hot[8] = Decimal::ONE,
231            CandlePattern::ThreeWhiteSoldiers => one_hot[9] = Decimal::ONE,
232            CandlePattern::ThreeBlackCrows => one_hot[10] = Decimal::ONE,
233            CandlePattern::PiercingPattern => one_hot[11] = Decimal::ONE,
234            CandlePattern::DarkCloudCover => one_hot[12] = Decimal::ONE,
235            CandlePattern::Harami => one_hot[13] = Decimal::ONE,
236            CandlePattern::HaramiCross => one_hot[14] = Decimal::ONE,
237            CandlePattern::SpinningTop => one_hot[15] = Decimal::ONE,
238        }
239
240        one_hot
241    }
242}
243
/// Flat record of numbered inputs and outputs attached to `PricePoint::debug`
/// and `PositionLog::debug`.
///
/// `input_1`..`input_29` and `input_40` are decimals, `input_30`..`input_39`
/// are `CandlePattern` values, and `output_1`..`output_5` are decimals. What
/// each slot represents is not defined in this file (apart from the note on
/// `input_40`).
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct DebugLog {
    pub input_1: Decimal,
    pub input_2: Decimal,
    pub input_3: Decimal,
    pub input_4: Decimal,
    pub input_5: Decimal,
    pub input_6: Decimal,
    pub input_7: Decimal,
    pub input_8: Decimal,
    pub input_9: Decimal,
    pub input_10: Decimal,
    pub input_11: Decimal,
    pub input_12: Decimal,
    pub input_13: Decimal,
    pub input_14: Decimal,
    pub input_15: Decimal,
    pub input_16: Decimal,
    pub input_17: Decimal,
    pub input_18: Decimal,
    pub input_19: Decimal,
    pub input_20: Decimal,
    pub input_21: Decimal,
    pub input_22: Decimal,
    pub input_23: Decimal,
    pub input_24: Decimal,
    pub input_25: Decimal,
    pub input_26: Decimal,
    pub input_27: Decimal,
    pub input_28: Decimal,
    pub input_29: Decimal,
    pub input_30: CandlePattern,
    pub input_31: CandlePattern,
    pub input_32: CandlePattern,
    pub input_33: CandlePattern,
    pub input_34: CandlePattern,
    pub input_35: CandlePattern,
    pub input_36: CandlePattern,
    pub input_37: CandlePattern,
    pub input_38: CandlePattern,
    pub input_39: CandlePattern,
    // When this field is absent from the serialized data, serde fills it with
    // `default_input_40()` (i.e. `Decimal::ONE`). Note that the derived
    // `Default` for the struct still yields `Decimal::default()` here.
    #[serde(default = "default_input_40")]
    pub input_40: Decimal, // volume_change_ratio for enhanced Inago detection
    pub output_1: Decimal,
    pub output_2: Decimal,
    pub output_3: Decimal,
    pub output_4: Decimal,
    pub output_5: Decimal,
}
293
/// Stored record of one position, including the `DebugLog` captured with it.
/// Written through `TransactionLog::update_transaction` and read back by
/// `TransactionLog::get_all_positions`.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PositionLog {
    pub id: Option<u32>,
    pub fund_name: String,
    pub order_id: String,
    pub ordered_price: Decimal,
    pub state: String,
    pub token_name: String,
    pub open_time_str: String,
    // Passed as the key argument by `TransactionLog::get_all_positions`.
    pub open_timestamp: i64,
    pub close_time_str: String,
    pub average_open_price: Decimal,
    pub position_type: String,
    pub close_price: Decimal,
    pub asset_in_usd: Decimal,
    pub pnl: Decimal,
    pub fee: Decimal,
    pub debug: DebugLog,
}
313
/// Wrapper around an opaque model byte buffer. `ModelParams` encodes and
/// decodes the whole struct with `bincode` when saving to or loading from the
/// database or a file.
#[derive(Serialize, Deserialize)]
pub struct SerializableModel {
    pub model: Vec<u8>,
}
318
impl HasId for PositionLog {
    /// Returns the record's `id` field (`None` when no id is set).
    fn id(&self) -> Option<u32> {
        self.id
    }
}
324
/// Handle bundling the id counters with the names of the read and write
/// databases and the shared MongoDB client holder.
pub struct TransactionLog {
    // Id counters, seeded in `new` from the last ids found in the write database.
    counter: Counter,
    // Database name used by `get_r_db`.
    db_r_name: String,
    // Database name used by `get_w_db`.
    db_w_name: String,
    // Shared client holder passed to `database::get` for every lookup.
    client_holder: Arc<Mutex<ClientHolder>>,
}
331
332impl TransactionLog {
333    pub async fn new(
334        max_position_counter: Option<u32>,
335        max_price_counter: Option<u32>,
336        max_pnl_counter: Option<u32>,
337        mongodb_uri: &str,
338        db_r_name: &str,
339        db_w_name: &str,
340        back_test: bool,
341    ) -> Self {
342        // Set up the DB client holder
343        let mut client_options = match ClientOptions::parse(mongodb_uri).await {
344            Ok(client_options) => client_options,
345            Err(e) => {
346                panic!("{:?}", e);
347            }
348        };
349        let tls_options = TlsOptions::builder().build();
350        client_options.tls = Some(Tls::Enabled(tls_options));
351        let client_holder = Arc::new(Mutex::new(ClientHolder::new(client_options)));
352
353        // Get database instances for read and write
354        let db_w = shared_mongodb::database::get(&client_holder, &db_w_name)
355            .await
356            .unwrap();
357        let db_r = shared_mongodb::database::get(&client_holder, &db_r_name)
358            .await
359            .unwrap();
360
361        // Ensure indexes exist in both read and write databases
362        create_unique_index(&db_w)
363            .await
364            .expect("Error creating unique index in db_w");
365        create_unique_index(&db_r)
366            .await
367            .expect("Error creating unique index in db_r");
368
369        if back_test {
370            if let Err(e) = Self::delete_all_positions(&db_w).await {
371                panic!("delete_all_positions failed: {:?}", e);
372            }
373            if let Err(e) = Self::delete_app_state(&db_w).await {
374                panic!("delete_app_state failed: {:?}", e);
375            }
376        }
377
378        let last_position_counter =
379            TransactionLog::get_last_transaction_id(&db_w, CounterType::Position).await;
380        let last_price_counter =
381            TransactionLog::get_last_transaction_id(&db_w, CounterType::Price).await;
382        let last_pnl_counter =
383            TransactionLog::get_last_transaction_id(&db_w, CounterType::Pnl).await;
384
385        let counter = Counter::new(
386            max_position_counter,
387            max_price_counter,
388            max_pnl_counter,
389            last_position_counter,
390            last_price_counter,
391            last_pnl_counter,
392        );
393
394        log::warn!(
395            "position = {}/{:?}, price = {}/{:?}, pnl = {}/{:?}",
396            last_position_counter,
397            max_position_counter,
398            last_price_counter,
399            max_price_counter,
400            last_pnl_counter,
401            max_pnl_counter,
402        );
403
404        TransactionLog {
405            counter,
406            db_r_name: db_r_name.to_owned(),
407            db_w_name: db_w_name.to_owned(),
408            client_holder,
409        }
410    }
411
    /// Delegates to `Counter::increment` for the given `counter_type` and
    /// returns the value it produces.
    pub fn increment_counter(&self, counter_type: CounterType) -> u32 {
        self.counter.increment(counter_type)
    }
415
    /// Returns the last stored id for the log type that corresponds to
    /// `counter_type`: `PositionLog` for `Position`, `PriceLog` for `Price`
    /// and `PnlLog` for `Pnl`.
    ///
    /// The lookup is done by `get_last_id`, which yields `0` when nothing is
    /// found or the search fails.
    pub async fn get_last_transaction_id(db: &Database, counter_type: CounterType) -> u32 {
        match counter_type {
            CounterType::Position => get_last_id::<PositionLog>(db).await,
            CounterType::Price => get_last_id::<PriceLog>(db).await,
            CounterType::Pnl => get_last_id::<PnlLog>(db).await,
        }
    }
423
    /// Returns a handle to the write database (`db_w_name`), or `None` when it
    /// cannot be obtained.
    pub async fn get_w_db(&self) -> Option<Database> {
        self.get_db(false).await
    }
427
    /// Returns a handle to the read database (`db_r_name`), or `None` when it
    /// cannot be obtained.
    pub async fn get_r_db(&self) -> Option<Database> {
        self.get_db(true).await
    }
431
432    async fn get_db(&self, read: bool) -> Option<Database> {
433        let db_name = if read {
434            &self.db_r_name
435        } else {
436            &self.db_w_name
437        };
438        let db = match database::get(&self.client_holder, db_name).await {
439            Ok(db) => Some(db),
440            Err(e) => {
441                log::error!("get_db: {:?}", e);
442                None
443            }
444        };
445        db
446    }
447
    /// Passes the position log `item` to `update_item` on `db`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `update_item`.
    pub async fn update_transaction(
        db: &Database,
        item: &PositionLog,
    ) -> Result<(), Box<dyn error::Error>> {
        update_item(db, item).await?;
        Ok(())
    }
455
    /// Passes the price log `item` (taken by value) to `update_item` on `db`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `update_item`.
    pub async fn update_price(db: &Database, item: PriceLog) -> Result<(), Box<dyn error::Error>> {
        update_item(db, &item).await?;
        Ok(())
    }
460
461    pub async fn copy_price(db_r: &Database, db_w: &Database, limit: Option<u32>) {
462        let item = PriceLog::default();
463        let items = {
464            match search_items(db_r, &item, SearchMode::Ascending, limit, None, Some("id")).await {
465                Ok(items) => items,
466                Err(e) => {
467                    log::error!("get price: {:?}", e);
468                    return;
469                }
470            }
471        };
472        log::debug!("get prices: num = {}", items.len());
473
474        for item in &items {
475            match insert_item(db_w, item).await {
476                Ok(_) => {}
477                Err(e) => {
478                    log::error!("write price: {:?}", e);
479                    return;
480                }
481            }
482        }
483    }
484
485    pub async fn copy_position(db_r: &Database, db_w: &Database, limit: Option<u32>) {
486        let item = PositionLog::default();
487        let items = {
488            match search_items(db_r, &item, SearchMode::Ascending, limit, None, Some("id")).await {
489                Ok(items) => items,
490                Err(e) => {
491                    log::error!("get position: {:?}", e);
492                    return;
493                }
494            }
495        };
496        log::debug!("get position: num = {}", items.len());
497
498        for item in &items {
499            match insert_item(db_w, item).await {
500                Ok(_) => {}
501                Err(e) => {
502                    log::error!("write position: {:?}", e);
503                    return;
504                }
505            }
506        }
507    }
508
509    pub async fn get_price_market_data(
510        db: &Database,
511        limit: Option<u32>,
512        id: Option<u32>,
513        is_ascend: bool,
514    ) -> HashMap<String, HashMap<String, Vec<PricePoint>>> {
515        let search_mode = if is_ascend {
516            SearchMode::Ascending
517        } else {
518            SearchMode::Descending
519        };
520        let sort_key = Some("price_point.timestamp");
521        let item = PriceLog::default();
522
523        let items = match id {
524            Some(id) => search_item(db, &item, Some(id), sort_key)
525                .await
526                .map(|item| vec![item]),
527            None => search_items(db, &item, search_mode, limit, None, sort_key).await,
528        };
529
530        let Ok(mut items) = items else {
531            log::warn!("get_price_market_data: search failed");
532            return HashMap::new();
533        };
534
535        items.sort_by_key(|p| p.price_point.timestamp);
536
537        let mut result = HashMap::new();
538        for price_log in items {
539            result
540                .entry(price_log.name)
541                .or_insert_with(HashMap::new)
542                .entry(price_log.token_name)
543                .or_insert_with(Vec::new)
544                .push(price_log.price_point);
545        }
546
547        result
548    }
549
550    pub async fn get_all_positions(
551        db: &Database,
552        limit: Option<u32>,
553        id: Option<u32>,
554        is_ascend: bool,
555    ) -> Vec<PositionLog> {
556        let search_mode = if is_ascend {
557            SearchMode::Ascending
558        } else {
559            SearchMode::Descending
560        };
561        let sort_key = Some("open_timestamp");
562        let item = PositionLog::default();
563
564        let items = if let Some(id) = id {
565            match search_item(db, &item, Some(id), sort_key).await {
566                Ok(position) => vec![position],
567                Err(e) => {
568                    log::warn!("get_all_positions: {:?}", e);
569                    vec![]
570                }
571            }
572        } else {
573            match search_items(db, &item, search_mode, limit, None, sort_key).await {
574                Ok(positions) => positions,
575                Err(e) => {
576                    log::warn!("get_all_positions: {:?}", e);
577                    vec![]
578                }
579            }
580        };
581
582        items
583    }
584
    /// Calls `delete_item_all` on `db` with a default `PositionLog`.
    ///
    /// Presumably this removes every stored position log, with the default
    /// value acting only as a type selector — confirm in `delete_item_all`.
    /// Invoked by `new` in back-test mode.
    async fn delete_all_positions(db: &Database) -> Result<(), Box<dyn error::Error>> {
        let item = PositionLog::default();
        delete_item_all(db, &item).await
    }
589
    /// Passes the PnL log `item` (taken by value) to `insert_item` on `db`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `insert_item`.
    pub async fn insert_pnl(db: &Database, item: PnlLog) -> Result<(), Box<dyn error::Error>> {
        insert_item(db, &item).await?;
        Ok(())
    }
594
595    pub async fn get_app_state(db: &Database) -> AppState {
596        let item = AppState::default();
597        match search_item(db, &item, Some(1), Some("id")).await {
598            Ok(item) => item,
599            Err(e) => {
600                log::warn!("get_app_state: {:?}", e);
601                item
602            }
603        }
604    }
605
    /// Calls `delete_item_all` on `db` with a default `AppState`.
    ///
    /// Presumably this removes the stored app state, with the default value
    /// acting only as a type selector — confirm in `delete_item_all`.
    /// Invoked by `new` in back-test mode.
    async fn delete_app_state(db: &Database) -> Result<(), Box<dyn error::Error>> {
        let item = AppState::default();
        delete_item_all(db, &item).await
    }
610
611    pub async fn update_app_state(
612        db: &Database,
613        last_execution_time: Option<SystemTime>,
614        last_equity: Option<Decimal>,
615        ave_dd: Option<Decimal>,
616        max_dd: Option<Decimal>,
617        cumulative_return: Option<Decimal>,
618        cumulative_dd: Option<Decimal>,
619        score: Option<Decimal>,
620        score_2: Option<Decimal>,
621        score_3: Option<Decimal>,
622        curcuit_break: bool,
623        error_time: Option<String>,
624        max_invested_amount: Option<Decimal>,
625        fund_configs: Option<Vec<FundConfig>>,
626    ) -> Result<(), Box<dyn error::Error>> {
627        let item = AppState::default();
628        let mut item = match search_item(db, &item, Some(1), Some("id")).await {
629            Ok(prev_item) => prev_item,
630            Err(_) => item,
631        };
632
633        if last_execution_time.is_some() {
634            item.last_execution_time = last_execution_time;
635        }
636
637        if let Some(last_equity) = last_equity {
638            item.last_equity = Some(last_equity.round());
639        }
640
641        if let Some(ave_dd) = ave_dd {
642            item.ave_dd = Some(ave_dd.round());
643        }
644
645        if let Some(max_dd_val) = max_dd {
646            if item
647                .max_dd
648                .map_or(true, |item_max_dd| max_dd_val > item_max_dd)
649            {
650                item.max_dd = Some(max_dd_val.round());
651            }
652        }
653
654        if let Some(cumulative_return) = cumulative_return {
655            item.cumulative_return += cumulative_return.round();
656        }
657
658        if let Some(cumulative_dd) = cumulative_dd {
659            item.cumulative_dd += cumulative_dd.round();
660        }
661
662        if score.is_some() {
663            item.score = score;
664        }
665
666        if score_2.is_some() {
667            item.score_2 = score_2;
668        }
669
670        if score_3.is_some() {
671            item.score_3 = score_3;
672        }
673
674        item.curcuit_break = curcuit_break;
675
676        if let Some(error_time) = error_time {
677            item.error_time.push(error_time);
678        }
679
680        if let Some(max_invested_amount) = max_invested_amount {
681            item.max_invested_amount = max_invested_amount.round();
682        }
683
684        if let Some(fund_configs) = fund_configs {
685            item.fund_configs = Some(fund_configs);
686        }
687
688        update_item(db, &item).await?;
689        Ok(())
690    }
691
    /// Returns the name of the write database.
    pub fn db_w_name(&self) -> &str {
        &self.db_w_name
    }
695}
696
/// Storage handle for serialized models, backed either by a MongoDB
/// collection or by files on disk.
#[derive(Clone)]
pub struct ModelParams {
    // Database that holds the model collection.
    db_name: String,
    // Shared client holder passed to `database::get`.
    client_holder: Arc<Mutex<ClientHolder>>,
    // Collection name; `new` sets it to "model_params".
    collection_name: String,
    // true: save/load through the database; false: through files.
    save_to_db: bool,
    // Optional directory for model files; when `None` the bare file name is
    // used as the path.
    file_path: Option<String>,
}
705
706impl ModelParams {
707    pub async fn new(
708        mongodb_uri: &str,
709        db_name: &str,
710        save_to_db: bool,
711        file_path: Option<String>,
712    ) -> Self {
713        // Set up the DB client holder
714        let mut client_options = match ClientOptions::parse(mongodb_uri).await {
715            Ok(client_options) => client_options,
716            Err(e) => {
717                panic!("{:?}", e);
718            }
719        };
720        let tls_options = TlsOptions::builder().build();
721        client_options.tls = Some(Tls::Enabled(tls_options));
722        let client_holder = Arc::new(Mutex::new(ClientHolder::new(client_options)));
723
724        ModelParams {
725            db_name: db_name.to_owned(),
726            client_holder,
727            collection_name: "model_params".to_owned(),
728            save_to_db,
729            file_path,
730        }
731    }
732
733    async fn get_db(&self) -> Option<Database> {
734        let db = match database::get(&self.client_holder, &self.db_name).await {
735            Ok(db) => Some(db),
736            Err(e) => {
737                log::error!("get_db: {:?}", e);
738                None
739            }
740        };
741        db
742    }
743
    /// Stores `model` under `key`, in the database when `save_to_db` is true
    /// and in a file otherwise.
    ///
    /// # Errors
    ///
    /// Returns whatever error the selected backend reports.
    pub async fn save_model(
        &self,
        key: &str,
        model: &SerializableModel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        if self.save_to_db {
            self.save_model_to_db(key, model).await
        } else {
            self.save_model_to_file(key, model).await
        }
    }
755
    /// Loads the model stored under `key`, from the database when `save_to_db`
    /// is true and from a file otherwise.
    ///
    /// # Errors
    ///
    /// Returns whatever error the selected backend reports.
    pub async fn load_model(
        &self,
        key: &str,
    ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
        if self.save_to_db {
            self.load_model_from_db(key).await
        } else {
            self.load_model_from_file(key).await
        }
    }
766
767    async fn save_model_to_db(
768        &self,
769        key: &str,
770        model: &SerializableModel,
771    ) -> Result<(), Box<dyn std::error::Error>> {
772        let db = self.get_db().await.ok_or("no db")?;
773        let collection: Collection<Document> = db.collection(&self.collection_name);
774        let serialized_model = bincode::serialize(model)?;
775
776        let document = doc! {
777            "key": key,
778            "model": Bson::Binary(mongodb::bson::Binary {
779                subtype: mongodb::bson::spec::BinarySubtype::Generic,
780                bytes: serialized_model
781            })
782        };
783
784        collection
785            .update_one(
786                doc! { "key": key },
787                doc! { "$set": document },
788                mongodb::options::UpdateOptions::builder()
789                    .upsert(true)
790                    .build(),
791            )
792            .await?;
793        Ok(())
794    }
795
796    async fn load_model_from_db(
797        &self,
798        key: &str,
799    ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
800        let db = self.get_db().await.ok_or("no db")?;
801        let collection: Collection<Document> = db.collection(&self.collection_name);
802
803        let filter = doc! { "key": key };
804        let document = collection
805            .find_one(filter, None)
806            .await?
807            .ok_or("No model found in the collection")?;
808
809        if let Some(Bson::Binary(model_bytes)) = document.get("model") {
810            let model: SerializableModel = bincode::deserialize(&model_bytes.bytes)?;
811            Ok(model)
812        } else {
813            Err("Invalid data format".into())
814        }
815    }
816
817    async fn save_model_to_file(
818        &self,
819        key: &str,
820        model: &SerializableModel,
821    ) -> Result<(), Box<dyn std::error::Error>> {
822        let serialized_model = bincode::serialize(model)?;
823        let file_name = format!("{}.bin", key);
824
825        let file_path = if let Some(ref dir) = self.file_path {
826            Path::new(dir).join(file_name)
827        } else {
828            Path::new(&file_name).to_path_buf()
829        };
830
831        let mut file = File::create(&file_path)?;
832        file.write_all(&serialized_model)?;
833        Ok(())
834    }
835
836    async fn load_model_from_file(
837        &self,
838        key: &str,
839    ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
840        let file_name = format!("{}.bin", key);
841
842        let file_path = if let Some(ref dir) = self.file_path {
843            Path::new(dir).join(file_name)
844        } else {
845            Path::new(&file_name).to_path_buf()
846        };
847
848        let mut file = File::open(&file_path).map_err(|e| {
849            log::error!("Failed to open file: {:?}: {}", file_path, e);
850            e
851        })?;
852        let mut buffer = Vec::new();
853        file.read_to_end(&mut buffer)?;
854        let model: SerializableModel = bincode::deserialize(&buffer)?;
855        Ok(model)
856    }
857}