// debot_db/transaction_log.rs

3use bson::doc;
4use bson::Bson;
5use bson::Document;
6use debot_utils::get_local_time;
7use debot_utils::HasId;
8use mongodb::Collection;
9use mongodb::{
10    options::{ClientOptions, Tls, TlsOptions},
11    Database,
12};
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15use shared_mongodb::{database, ClientHolder};
16use std::collections::HashMap;
17use std::error;
18use std::fs::File;
19use std::io::{Read, Write};
20use std::path::Path;
21use std::sync::Arc;
22use std::time::SystemTime;
23use tokio::sync::Mutex;
24
25use crate::delete_item_all;
26use crate::SearchMode;
27use crate::TradingStrategy;
28use crate::{
29    create_unique_index, insert_item, search_item, search_items, update_item, Counter, CounterType,
30    Entity,
31};
32
33async fn get_last_id<T: Default + Entity + HasId>(db: &Database) -> u32 {
34    let item = T::default();
35    match search_items(
36        db,
37        &item,
38        crate::SearchMode::Descending,
39        Some(1),
40        None,
41        Some("id"),
42    )
43    .await
44    {
45        Ok(mut items) => items.pop().and_then(|item| item.id()).unwrap_or(0),
46        Err(e) => {
47            log::info!("get_last_id: {:?}", e);
48            0
49        }
50    }
51}
52
53#[derive(Serialize, Deserialize, Clone, Debug)]
54pub enum SampleTerm {
55    TradingTerm,
56    ShortTerm,
57    LongTerm,
58}
59
60impl SampleTerm {
61    pub fn to_numeric(&self) -> Decimal {
62        match self {
63            SampleTerm::TradingTerm => Decimal::new(1, 0),
64            SampleTerm::ShortTerm => Decimal::new(2, 0),
65            SampleTerm::LongTerm => Decimal::new(3, 0),
66        }
67    }
68}
69
/// Per-token fund configuration; a list of these is stored in [`AppState::fund_configs`].
///
/// Field semantics below beyond the type are inferred from names and should be
/// confirmed against the code that consumes this struct.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct FundConfig {
    /// Token this configuration applies to.
    pub token: String,
    /// Strategy used for `token`.
    pub trading_strategy: TradingStrategy,
    /// Balance allotted per strategy (unit not visible here — presumably USD; confirm).
    pub balance_per_strategy: Decimal,
    /// Risk/reward setting.
    pub risk_reward: Decimal,
    /// Optional take-profit ratio.
    pub take_profit_ratio: Option<Decimal>,
    /// Optional ATR-based spread.
    pub atr_spread: Option<Decimal>,
    /// Sampling horizon associated with the ATR settings.
    pub atr_term: SampleTerm,
    /// Duration in minutes (NOTE(review): presumably how long a position/order may stay open — confirm).
    pub open_minutes: i64,
    /// Optional ATR multiplier (NOTE(review): name suggests flash-crash detection — confirm).
    pub flash_crash_atr_multiplier: Option<Decimal>,
}
82
/// Application-wide state persisted as a single record (`TransactionLog::get_app_state`
/// and `TransactionLog::update_app_state` always address `id == 1`).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct AppState {
    /// Record id; the default value is 1.
    pub id: u32,
    /// Time of the last recorded execution, if any.
    pub last_execution_time: Option<SystemTime>,
    /// Last recorded equity (stored rounded by `update_app_state`).
    pub last_equity: Option<Decimal>,
    /// Average drawdown (stored rounded by `update_app_state`).
    pub ave_dd: Option<Decimal>,
    /// Largest drawdown seen; `update_app_state` only ever raises it.
    pub max_dd: Option<Decimal>,
    /// Running total; `update_app_state` adds rounded deltas to it.
    pub cumulative_return: Decimal,
    /// Running total; `update_app_state` adds rounded deltas to it.
    pub cumulative_dd: Decimal,
    pub score: Option<Decimal>,
    pub score_2: Option<Decimal>,
    pub score_3: Option<Decimal>,
    /// Circuit-breaker flag. The misspelling is kept on purpose: it is the serialized
    /// field name of already-stored documents.
    pub curcuit_break: bool,
    /// Error timestamps as strings; `update_app_state` appends to this list.
    pub error_time: Vec<String>,
    /// Maximum invested amount (stored rounded by `update_app_state`).
    pub max_invested_amount: Decimal,
    /// Per-token fund configurations; `Some(vec![])` by default.
    pub fund_configs: Option<Vec<FundConfig>>,
}
100
101impl Default for AppState {
102    fn default() -> Self {
103        Self {
104            id: 1,
105            last_execution_time: None,
106            last_equity: None,
107            ave_dd: None,
108            max_dd: None,
109            cumulative_return: Decimal::ZERO,
110            cumulative_dd: Decimal::ZERO,
111            score: None,
112            score_2: None,
113            score_3: None,
114            curcuit_break: false,
115            error_time: vec![],
116            max_invested_amount: Decimal::ZERO,
117            fund_configs: Some(vec![]),
118        }
119    }
120}
121
/// One profit-and-loss record.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PnlLog {
    /// Record id; the largest stored value seeds the PnL counter in `TransactionLog::new`.
    pub id: Option<u32>,
    /// Date label (format not visible here).
    pub date: String,
    /// PnL amount.
    pub pnl: Decimal,
}
128
129impl HasId for PnlLog {
130    fn id(&self) -> Option<u32> {
131        self.id
132    }
133}
134
/// A single price sample with optional market metadata.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct PricePoint {
    /// Sample timestamp; `TransactionLog::get_price_market_data` sorts on this field.
    /// Units not visible here (presumably epoch seconds or millis — confirm).
    pub timestamp: i64,
    /// Human-readable time string. NOTE(review): `PricePoint::new` always fills this from
    /// `get_local_time()`, even when an explicit `timestamp` is supplied, so the two
    /// fields can disagree.
    pub timestamp_str: String,
    pub price: Decimal,
    pub volume: Option<Decimal>,
    pub num_trades: Option<u64>,
    pub funding_rate: Option<Decimal>,
    pub open_interest: Option<Decimal>,
    pub oracle_price: Option<Decimal>,
}
146
147impl PricePoint {
148    pub fn new(
149        price: Decimal,
150        timestamp: Option<i64>,
151        volume: Option<Decimal>,
152        num_trades: Option<u64>,
153        funding_rate: Option<Decimal>,
154        open_interest: Option<Decimal>,
155        oracle_price: Option<Decimal>,
156    ) -> Self {
157        let (local_timestamp, timestamp_str) = get_local_time();
158        let timestamp = timestamp.unwrap_or(local_timestamp);
159        Self {
160            timestamp,
161            timestamp_str,
162            price,
163            volume,
164            num_trades,
165            funding_rate,
166            open_interest,
167            oracle_price,
168        }
169    }
170}
171
/// A stored price sample. `TransactionLog::get_price_market_data` groups these by
/// `name`, then by `token_name`.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PriceLog {
    /// Record id; the largest stored value seeds the price counter in `TransactionLog::new`.
    pub id: Option<u32>,
    /// Outer grouping key (presumably a source/market name — confirm).
    pub name: String,
    /// Inner grouping key: the token the sample belongs to.
    pub token_name: String,
    /// The sample itself.
    pub price_point: PricePoint,
}
179
180impl HasId for PriceLog {
181    fn id(&self) -> Option<u32> {
182        self.id
183    }
184}
185
/// Candlestick pattern label (16 variants; see `CandlePattern::to_one_hot`).
/// `None` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum CandlePattern {
    #[default]
    None,
    Hammer,
    InvertedHammer,
    BullishEngulfing,
    BearishEngulfing,
    Doji,
    Marubozu,
    MorningStar,
    EveningStar,
    ThreeWhiteSoldiers,
    ThreeBlackCrows,
    PiercingPattern,
    DarkCloudCover,
    Harami,
    HaramiCross,
    SpinningTop,
}
206
207impl CandlePattern {
208    pub fn to_one_hot(&self) -> [Decimal; 16] {
209        let mut one_hot = [Decimal::ZERO; 16];
210
211        match self {
212            CandlePattern::None => one_hot[0] = Decimal::ONE,
213            CandlePattern::Hammer => one_hot[1] = Decimal::ONE,
214            CandlePattern::InvertedHammer => one_hot[2] = Decimal::ONE,
215            CandlePattern::BullishEngulfing => one_hot[3] = Decimal::ONE,
216            CandlePattern::BearishEngulfing => one_hot[4] = Decimal::ONE,
217            CandlePattern::Doji => one_hot[5] = Decimal::ONE,
218            CandlePattern::Marubozu => one_hot[6] = Decimal::ONE,
219            CandlePattern::MorningStar => one_hot[7] = Decimal::ONE,
220            CandlePattern::EveningStar => one_hot[8] = Decimal::ONE,
221            CandlePattern::ThreeWhiteSoldiers => one_hot[9] = Decimal::ONE,
222            CandlePattern::ThreeBlackCrows => one_hot[10] = Decimal::ONE,
223            CandlePattern::PiercingPattern => one_hot[11] = Decimal::ONE,
224            CandlePattern::DarkCloudCover => one_hot[12] = Decimal::ONE,
225            CandlePattern::Harami => one_hot[13] = Decimal::ONE,
226            CandlePattern::HaramiCross => one_hot[14] = Decimal::ONE,
227            CandlePattern::SpinningTop => one_hot[15] = Decimal::ONE,
228        }
229
230        one_hot
231    }
232}
233
/// Snapshot of numbered inputs/outputs attached to a position (`PositionLog::debug`).
///
/// The meaning of each numbered slot is defined by the code that fills it, which is not
/// visible here.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct DebugLog {
    // Numeric inputs: input_1 ..= input_29.
    pub input_1: Decimal,
    pub input_2: Decimal,
    pub input_3: Decimal,
    pub input_4: Decimal,
    pub input_5: Decimal,
    pub input_6: Decimal,
    pub input_7: Decimal,
    pub input_8: Decimal,
    pub input_9: Decimal,
    pub input_10: Decimal,
    pub input_11: Decimal,
    pub input_12: Decimal,
    pub input_13: Decimal,
    pub input_14: Decimal,
    pub input_15: Decimal,
    pub input_16: Decimal,
    pub input_17: Decimal,
    pub input_18: Decimal,
    pub input_19: Decimal,
    pub input_20: Decimal,
    pub input_21: Decimal,
    pub input_22: Decimal,
    pub input_23: Decimal,
    pub input_24: Decimal,
    pub input_25: Decimal,
    pub input_26: Decimal,
    pub input_27: Decimal,
    pub input_28: Decimal,
    pub input_29: Decimal,
    // Candle-pattern inputs: input_30 ..= input_39 (presumably one per recent candle — confirm).
    pub input_30: CandlePattern,
    pub input_31: CandlePattern,
    pub input_32: CandlePattern,
    pub input_33: CandlePattern,
    pub input_34: CandlePattern,
    pub input_35: CandlePattern,
    pub input_36: CandlePattern,
    pub input_37: CandlePattern,
    pub input_38: CandlePattern,
    pub input_39: CandlePattern,
    // Outputs: output_1 and output_2 are always present; output_3 ..= output_5 are optional.
    pub output_1: Decimal,
    pub output_2: Decimal,
    pub output_3: Option<Decimal>,
    pub output_4: Option<Decimal>,
    pub output_5: Option<Decimal>,
}
281
/// One stored trading position.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PositionLog {
    /// Record id; the largest stored value seeds the position counter in `TransactionLog::new`.
    pub id: Option<u32>,
    pub fund_name: String,
    pub order_id: String,
    pub ordered_price: Decimal,
    /// Position state as free-form text (allowed values not visible here).
    pub state: String,
    pub token_name: String,
    pub open_time_str: String,
    /// Sort key used by `TransactionLog::get_all_positions`.
    pub open_timestamp: i64,
    pub close_time_str: String,
    pub average_open_price: Decimal,
    /// Position type as free-form text (presumably long/short — confirm).
    pub position_type: String,
    pub close_price: Decimal,
    pub asset_in_usd: Decimal,
    pub pnl: Decimal,
    pub fee: Decimal,
    /// Input/output snapshot recorded with this position.
    pub debug: DebugLog,
}
301
/// Opaque model bytes. `ModelParams` bincode-encodes this whole wrapper (not just the
/// raw bytes) before storing it in MongoDB or in a `.bin` file.
#[derive(Serialize, Deserialize)]
pub struct SerializableModel {
    pub model: Vec<u8>,
}
306
307impl HasId for PositionLog {
308    fn id(&self) -> Option<u32> {
309        self.id
310    }
311}
312
/// MongoDB-backed store for positions, prices, PnL records and the singleton [`AppState`].
///
/// Reads and writes may target different databases (`db_r_name` / `db_w_name`) through
/// one shared client holder.
pub struct TransactionLog {
    /// Id counters for positions, prices and PnL, seeded in `new` from the write database.
    counter: Counter,
    /// Database returned by `get_r_db`.
    db_r_name: String,
    /// Database returned by `get_w_db`; also the database prepared in `new`.
    db_w_name: String,
    /// Client holder shared by both databases.
    client_holder: Arc<Mutex<ClientHolder>>,
}
319
320impl TransactionLog {
321    pub async fn new(
322        max_position_counter: Option<u32>,
323        max_price_counter: Option<u32>,
324        max_pnl_counter: Option<u32>,
325        mongodb_uri: &str,
326        db_r_name: &str,
327        db_w_name: &str,
328        back_test: bool,
329    ) -> Self {
330        // Set up the DB client holder
331        let mut client_options = match ClientOptions::parse(mongodb_uri).await {
332            Ok(client_options) => client_options,
333            Err(e) => {
334                panic!("{:?}", e);
335            }
336        };
337        let tls_options = TlsOptions::builder().build();
338        client_options.tls = Some(Tls::Enabled(tls_options));
339        let client_holder = Arc::new(Mutex::new(ClientHolder::new(client_options)));
340
341        let db = shared_mongodb::database::get(&client_holder, &db_w_name)
342            .await
343            .unwrap();
344
345        create_unique_index(&db)
346            .await
347            .expect("Error creating unique index");
348
349        if back_test {
350            if let Err(e) = Self::delete_all_positions(&db).await {
351                panic!("delete_all_positions failed: {:?}", e);
352            }
353            if let Err(e) = Self::delete_app_state(&db).await {
354                panic!("delete_app_state failed: {:?}", e);
355            }
356        }
357
358        let last_position_counter =
359            TransactionLog::get_last_transaction_id(&db, CounterType::Position).await;
360        let last_price_counter =
361            TransactionLog::get_last_transaction_id(&db, CounterType::Price).await;
362        let last_pnl_counter = TransactionLog::get_last_transaction_id(&db, CounterType::Pnl).await;
363
364        let counter = Counter::new(
365            max_position_counter,
366            max_price_counter,
367            max_pnl_counter,
368            last_position_counter,
369            last_price_counter,
370            last_pnl_counter,
371        );
372
373        log::warn!(
374            "position = {}/{:?}, price = {}/{:?}, pnl = {}/{:?}",
375            last_position_counter,
376            max_position_counter,
377            last_price_counter,
378            max_price_counter,
379            last_pnl_counter,
380            max_pnl_counter,
381        );
382
383        TransactionLog {
384            counter,
385            db_r_name: db_r_name.to_owned(),
386            db_w_name: db_w_name.to_owned(),
387            client_holder,
388        }
389    }
390
391    pub fn increment_counter(&self, counter_type: CounterType) -> u32 {
392        self.counter.increment(counter_type)
393    }
394
395    pub async fn get_last_transaction_id(db: &Database, counter_type: CounterType) -> u32 {
396        match counter_type {
397            CounterType::Position => get_last_id::<PositionLog>(db).await,
398            CounterType::Price => get_last_id::<PriceLog>(db).await,
399            CounterType::Pnl => get_last_id::<PnlLog>(db).await,
400        }
401    }
402
403    pub async fn get_w_db(&self) -> Option<Database> {
404        self.get_db(false).await
405    }
406
407    pub async fn get_r_db(&self) -> Option<Database> {
408        self.get_db(true).await
409    }
410
411    async fn get_db(&self, read: bool) -> Option<Database> {
412        let db_name = if read {
413            &self.db_r_name
414        } else {
415            &self.db_w_name
416        };
417        let db = match database::get(&self.client_holder, db_name).await {
418            Ok(db) => Some(db),
419            Err(e) => {
420                log::error!("get_db: {:?}", e);
421                None
422            }
423        };
424        db
425    }
426
427    pub async fn update_transaction(
428        db: &Database,
429        item: &PositionLog,
430    ) -> Result<(), Box<dyn error::Error>> {
431        update_item(db, item).await?;
432        Ok(())
433    }
434
435    pub async fn update_price(db: &Database, item: PriceLog) -> Result<(), Box<dyn error::Error>> {
436        update_item(db, &item).await?;
437        Ok(())
438    }
439
440    pub async fn copy_price(db_r: &Database, db_w: &Database, limit: Option<u32>) {
441        let item = PriceLog::default();
442        let items = {
443            match search_items(db_r, &item, SearchMode::Ascending, limit, None, Some("id")).await {
444                Ok(items) => items,
445                Err(e) => {
446                    log::error!("get price: {:?}", e);
447                    return;
448                }
449            }
450        };
451        log::debug!("get prices: num = {}", items.len());
452
453        for item in &items {
454            match insert_item(db_w, item).await {
455                Ok(_) => {}
456                Err(e) => {
457                    log::error!("write price: {:?}", e);
458                    return;
459                }
460            }
461        }
462    }
463
464    pub async fn copy_position(db_r: &Database, db_w: &Database, limit: Option<u32>) {
465        let item = PositionLog::default();
466        let items = {
467            match search_items(db_r, &item, SearchMode::Ascending, limit, None, Some("id")).await {
468                Ok(items) => items,
469                Err(e) => {
470                    log::error!("get position: {:?}", e);
471                    return;
472                }
473            }
474        };
475        log::debug!("get position: num = {}", items.len());
476
477        for item in &items {
478            match insert_item(db_w, item).await {
479                Ok(_) => {}
480                Err(e) => {
481                    log::error!("write position: {:?}", e);
482                    return;
483                }
484            }
485        }
486    }
487
488    pub async fn get_price_market_data(
489        db: &Database,
490        limit: Option<u32>,
491        id: Option<u32>,
492        is_ascend: bool,
493    ) -> HashMap<String, HashMap<String, Vec<PricePoint>>> {
494        let search_mode = if is_ascend {
495            SearchMode::Ascending
496        } else {
497            SearchMode::Descending
498        };
499        let sort_key = Some("price_point.timestamp");
500        let item = PriceLog::default();
501
502        let items = match id {
503            Some(id) => search_item(db, &item, Some(id), sort_key)
504                .await
505                .map(|item| vec![item]),
506            None => search_items(db, &item, search_mode, limit, None, sort_key).await,
507        };
508
509        let Ok(mut items) = items else {
510            log::warn!("get_price_market_data: search failed");
511            return HashMap::new();
512        };
513
514        items.sort_by_key(|p| p.price_point.timestamp);
515
516        let mut result = HashMap::new();
517        for price_log in items {
518            result
519                .entry(price_log.name)
520                .or_insert_with(HashMap::new)
521                .entry(price_log.token_name)
522                .or_insert_with(Vec::new)
523                .push(price_log.price_point);
524        }
525
526        result
527    }
528
529    pub async fn get_all_positions(
530        db: &Database,
531        limit: Option<u32>,
532        id: Option<u32>,
533        is_ascend: bool,
534    ) -> Vec<PositionLog> {
535        let search_mode = if is_ascend {
536            SearchMode::Ascending
537        } else {
538            SearchMode::Descending
539        };
540        let sort_key = Some("open_timestamp");
541        let item = PositionLog::default();
542
543        let items = if let Some(id) = id {
544            match search_item(db, &item, Some(id), sort_key).await {
545                Ok(position) => vec![position],
546                Err(e) => {
547                    log::warn!("get_all_positions: {:?}", e);
548                    vec![]
549                }
550            }
551        } else {
552            match search_items(db, &item, search_mode, limit, None, sort_key).await {
553                Ok(positions) => positions,
554                Err(e) => {
555                    log::warn!("get_all_positions: {:?}", e);
556                    vec![]
557                }
558            }
559        };
560
561        items
562    }
563
564    async fn delete_all_positions(db: &Database) -> Result<(), Box<dyn error::Error>> {
565        let item = PositionLog::default();
566        delete_item_all(db, &item).await
567    }
568
569    pub async fn insert_pnl(db: &Database, item: PnlLog) -> Result<(), Box<dyn error::Error>> {
570        insert_item(db, &item).await?;
571        Ok(())
572    }
573
574    pub async fn get_app_state(db: &Database) -> AppState {
575        let item = AppState::default();
576        match search_item(db, &item, Some(1), Some("id")).await {
577            Ok(item) => item,
578            Err(e) => {
579                log::warn!("get_app_state: {:?}", e);
580                item
581            }
582        }
583    }
584
585    async fn delete_app_state(db: &Database) -> Result<(), Box<dyn error::Error>> {
586        let item = AppState::default();
587        delete_item_all(db, &item).await
588    }
589
590    pub async fn update_app_state(
591        db: &Database,
592        last_execution_time: Option<SystemTime>,
593        last_equity: Option<Decimal>,
594        ave_dd: Option<Decimal>,
595        max_dd: Option<Decimal>,
596        cumulative_return: Option<Decimal>,
597        cumulative_dd: Option<Decimal>,
598        score: Option<Decimal>,
599        score_2: Option<Decimal>,
600        score_3: Option<Decimal>,
601        curcuit_break: bool,
602        error_time: Option<String>,
603        max_invested_amount: Option<Decimal>,
604        fund_configs: Option<Vec<FundConfig>>,
605    ) -> Result<(), Box<dyn error::Error>> {
606        let item = AppState::default();
607        let mut item = match search_item(db, &item, Some(1), Some("id")).await {
608            Ok(prev_item) => prev_item,
609            Err(_) => item,
610        };
611
612        if last_execution_time.is_some() {
613            item.last_execution_time = last_execution_time;
614        }
615
616        if let Some(last_equity) = last_equity {
617            item.last_equity = Some(last_equity.round());
618        }
619
620        if let Some(ave_dd) = ave_dd {
621            item.ave_dd = Some(ave_dd.round());
622        }
623
624        if let Some(max_dd_val) = max_dd {
625            if item
626                .max_dd
627                .map_or(true, |item_max_dd| max_dd_val > item_max_dd)
628            {
629                item.max_dd = Some(max_dd_val.round());
630            }
631        }
632
633        if let Some(cumulative_return) = cumulative_return {
634            item.cumulative_return += cumulative_return.round();
635        }
636
637        if let Some(cumulative_dd) = cumulative_dd {
638            item.cumulative_dd += cumulative_dd.round();
639        }
640
641        if score.is_some() {
642            item.score = score;
643        }
644
645        if score_2.is_some() {
646            item.score_2 = score_2;
647        }
648
649        if score_3.is_some() {
650            item.score_3 = score_3;
651        }
652
653        item.curcuit_break = curcuit_break;
654
655        if let Some(error_time) = error_time {
656            item.error_time.push(error_time);
657        }
658
659        if let Some(max_invested_amount) = max_invested_amount {
660            item.max_invested_amount = max_invested_amount.round();
661        }
662
663        if let Some(fund_configs) = fund_configs {
664            item.fund_configs = Some(fund_configs);
665        }
666
667        update_item(db, &item).await?;
668        Ok(())
669    }
670
671    pub fn db_w_name(&self) -> &str {
672        &self.db_w_name
673    }
674}
675
/// Stores and loads [`SerializableModel`]s either in a MongoDB collection or as files.
#[derive(Clone)]
pub struct ModelParams {
    /// Database holding the model collection.
    db_name: String,
    /// Client holder; clones of `ModelParams` share it through the `Arc`.
    client_holder: Arc<Mutex<ClientHolder>>,
    /// Collection name; `ModelParams::new` sets it to `"model_params"`.
    collection_name: String,
    /// `true` → models live in MongoDB, `false` → in `<key>.bin` files.
    save_to_db: bool,
    /// Directory for model files; `None` means paths relative to the current directory.
    file_path: Option<String>,
}
684
685impl ModelParams {
686    pub async fn new(
687        mongodb_uri: &str,
688        db_name: &str,
689        save_to_db: bool,
690        file_path: Option<String>,
691    ) -> Self {
692        // Set up the DB client holder
693        let mut client_options = match ClientOptions::parse(mongodb_uri).await {
694            Ok(client_options) => client_options,
695            Err(e) => {
696                panic!("{:?}", e);
697            }
698        };
699        let tls_options = TlsOptions::builder().build();
700        client_options.tls = Some(Tls::Enabled(tls_options));
701        let client_holder = Arc::new(Mutex::new(ClientHolder::new(client_options)));
702
703        ModelParams {
704            db_name: db_name.to_owned(),
705            client_holder,
706            collection_name: "model_params".to_owned(),
707            save_to_db,
708            file_path,
709        }
710    }
711
712    async fn get_db(&self) -> Option<Database> {
713        let db = match database::get(&self.client_holder, &self.db_name).await {
714            Ok(db) => Some(db),
715            Err(e) => {
716                log::error!("get_db: {:?}", e);
717                None
718            }
719        };
720        db
721    }
722
723    pub async fn save_model(
724        &self,
725        key: &str,
726        model: &SerializableModel,
727    ) -> Result<(), Box<dyn std::error::Error>> {
728        if self.save_to_db {
729            self.save_model_to_db(key, model).await
730        } else {
731            self.save_model_to_file(key, model).await
732        }
733    }
734
735    pub async fn load_model(
736        &self,
737        key: &str,
738    ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
739        if self.save_to_db {
740            self.load_model_from_db(key).await
741        } else {
742            self.load_model_from_file(key).await
743        }
744    }
745
746    async fn save_model_to_db(
747        &self,
748        key: &str,
749        model: &SerializableModel,
750    ) -> Result<(), Box<dyn std::error::Error>> {
751        let db = self.get_db().await.ok_or("no db")?;
752        let collection: Collection<Document> = db.collection(&self.collection_name);
753        let serialized_model = bincode::serialize(model)?;
754
755        let document = doc! {
756            "key": key,
757            "model": Bson::Binary(mongodb::bson::Binary {
758                subtype: mongodb::bson::spec::BinarySubtype::Generic,
759                bytes: serialized_model
760            })
761        };
762
763        collection
764            .update_one(
765                doc! { "key": key },
766                doc! { "$set": document },
767                mongodb::options::UpdateOptions::builder()
768                    .upsert(true)
769                    .build(),
770            )
771            .await?;
772        Ok(())
773    }
774
775    async fn load_model_from_db(
776        &self,
777        key: &str,
778    ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
779        let db = self.get_db().await.ok_or("no db")?;
780        let collection: Collection<Document> = db.collection(&self.collection_name);
781
782        let filter = doc! { "key": key };
783        let document = collection
784            .find_one(filter, None)
785            .await?
786            .ok_or("No model found in the collection")?;
787
788        if let Some(Bson::Binary(model_bytes)) = document.get("model") {
789            let model: SerializableModel = bincode::deserialize(&model_bytes.bytes)?;
790            Ok(model)
791        } else {
792            Err("Invalid data format".into())
793        }
794    }
795
796    async fn save_model_to_file(
797        &self,
798        key: &str,
799        model: &SerializableModel,
800    ) -> Result<(), Box<dyn std::error::Error>> {
801        let serialized_model = bincode::serialize(model)?;
802        let file_name = format!("{}.bin", key);
803
804        let file_path = if let Some(ref dir) = self.file_path {
805            Path::new(dir).join(file_name)
806        } else {
807            Path::new(&file_name).to_path_buf()
808        };
809
810        let mut file = File::create(&file_path)?;
811        file.write_all(&serialized_model)?;
812        Ok(())
813    }
814
815    async fn load_model_from_file(
816        &self,
817        key: &str,
818    ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
819        let file_name = format!("{}.bin", key);
820
821        let file_path = if let Some(ref dir) = self.file_path {
822            Path::new(dir).join(file_name)
823        } else {
824            Path::new(&file_name).to_path_buf()
825        };
826
827        let mut file = File::open(&file_path)?;
828        let mut buffer = Vec::new();
829        file.read_to_end(&mut buffer)?;
830        let model: SerializableModel = bincode::deserialize(&buffer)?;
831        Ok(model)
832    }
833}