1use bson::doc;
4use bson::Bson;
5use bson::Document;
6use debot_utils::get_local_time;
7use debot_utils::HasId;
8use mongodb::Collection;
9use mongodb::{
10 options::{ClientOptions, Tls, TlsOptions},
11 Database,
12};
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15use shared_mongodb::{database, ClientHolder};
16use std::collections::HashMap;
17use std::error;
18use std::fs::File;
19use std::io::{Read, Write};
20use std::path::Path;
21use std::sync::Arc;
22use std::time::SystemTime;
23use tokio::sync::Mutex;
24
25use crate::delete_item_all;
26use crate::SearchMode;
27use crate::TradingStrategy;
28use crate::{
29 create_unique_index, insert_item, search_item, search_items, update_item, Counter, CounterType,
30 Entity,
31};
32
33async fn get_last_id<T: Default + Entity + HasId>(db: &Database) -> u32 {
34 let item = T::default();
35 match search_items(
36 db,
37 &item,
38 crate::SearchMode::Descending,
39 Some(1),
40 None,
41 Some("id"),
42 )
43 .await
44 {
45 Ok(mut items) => items.pop().and_then(|item| item.id()).unwrap_or(0),
46 Err(e) => {
47 log::info!("get_last_id: {:?}", e);
48 0
49 }
50 }
51}
52
/// Sampling horizon selector, used for example as `FundConfig::atr_term`.
///
/// Serialized by variant name, so renaming a variant breaks stored configurations.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum SampleTerm {
    TradingTerm,
    ShortTerm,
    LongTerm,
}
59
60impl SampleTerm {
61 pub fn to_numeric(&self) -> Decimal {
62 match self {
63 SampleTerm::TradingTerm => Decimal::new(1, 0),
64 SampleTerm::ShortTerm => Decimal::new(2, 0),
65 SampleTerm::LongTerm => Decimal::new(3, 0),
66 }
67 }
68}
69
/// Per-fund trading configuration, stored as a list in `AppState::fund_configs`.
///
/// Field meanings are defined by the strategy code that consumes this struct, which is not
/// visible here. The notes below are limited to what the types show, plus clearly marked
/// assumptions. Field names are serialized as-is (no serde renames).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct FundConfig {
    // presumably the token symbol this config trades — confirm with callers
    pub token: String,
    pub trading_strategy: TradingStrategy,
    pub balance_per_strategy: Decimal,
    pub risk_reward: Decimal,
    /// Optional; `None` presumably means "use the strategy's default" — TODO confirm.
    pub take_profit_ratio: Option<Decimal>,
    pub atr_spread: Option<Decimal>,
    /// Horizon used for the ATR sample (see `SampleTerm`).
    pub atr_term: SampleTerm,
    // NOTE(review): presumably a duration in minutes, as the name suggests — confirm units
    pub open_minutes: i64,
    pub flash_crash_atr_multiplier: Option<Decimal>,
    pub order_size_multiplier: Decimal,
    pub flash_crash_min_percentage: Option<Decimal>,
    pub breakout_atr_multiplier: Decimal,
}
85
/// Persistent application state.
///
/// `TransactionLog::get_app_state` and `TransactionLog::update_app_state` read and write the
/// document whose `id` is 1, which is also the id produced by `AppState::default()`.
///
/// Field names are serialized as-is (there are no serde renames). `curcuit_break` therefore
/// keeps its existing spelling: renaming it would stop stored documents from deserializing.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct AppState {
    /// Document id; the state is kept in a single document with id 1.
    pub id: u32,
    pub last_execution_time: Option<SystemTime>,
    /// Stored rounded by `update_app_state`.
    pub last_equity: Option<Decimal>,
    /// Stored rounded by `update_app_state`.
    pub ave_dd: Option<Decimal>,
    /// Largest value seen; `update_app_state` only replaces it with a larger one (rounded).
    pub max_dd: Option<Decimal>,
    /// Running sum; `update_app_state` adds each rounded increment.
    pub cumulative_return: Decimal,
    /// Running sum; `update_app_state` adds each rounded increment.
    pub cumulative_dd: Decimal,
    pub score: Option<Decimal>,
    pub score_2: Option<Decimal>,
    pub score_3: Option<Decimal>,
    /// Circuit-breaker flag (misspelled name kept for serialization compatibility).
    pub curcuit_break: bool,
    /// Strings appended by `update_app_state`; presumably formatted error timestamps.
    pub error_time: Vec<String>,
    /// Stored rounded by `update_app_state`.
    pub max_invested_amount: Decimal,
    /// `Some(vec![])` by default.
    pub fund_configs: Option<Vec<FundConfig>>,
}
103
104impl Default for AppState {
105 fn default() -> Self {
106 Self {
107 id: 1,
108 last_execution_time: None,
109 last_equity: None,
110 ave_dd: None,
111 max_dd: None,
112 cumulative_return: Decimal::ZERO,
113 cumulative_dd: Decimal::ZERO,
114 score: None,
115 score_2: None,
116 score_3: None,
117 curcuit_break: false,
118 error_time: vec![],
119 max_invested_amount: Decimal::ZERO,
120 fund_configs: Some(vec![]),
121 }
122 }
123}
124
/// A profit-and-loss record, inserted via `TransactionLog::insert_pnl`.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PnlLog {
    /// Document id; `None` in `PnlLog::default()`.
    pub id: Option<u32>,
    // presumably a formatted date string — format not visible here
    pub date: String,
    pub pnl: Decimal,
}
131
132impl HasId for PnlLog {
133 fn id(&self) -> Option<u32> {
134 self.id
135 }
136}
137
/// A single market sample for one token, embedded in `PriceLog::price_point`.
///
/// `TransactionLog::get_price_market_data` sorts and groups samples by `timestamp`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct PricePoint {
    /// Numeric timestamp from `get_local_time()` or the caller; units not visible here.
    pub timestamp: i64,
    /// Human-readable time string from `get_local_time()`.
    pub timestamp_str: String,
    pub price: Decimal,
    pub volume: Option<Decimal>,
    pub num_trades: Option<u64>,
    pub funding_rate: Option<Decimal>,
    pub open_interest: Option<Decimal>,
    pub oracle_price: Option<Decimal>,
    /// Diagnostic values recorded alongside the sample.
    pub debug: DebugLog,
}
150
151impl PricePoint {
152 pub fn new(
153 price: Decimal,
154 timestamp: Option<i64>,
155 volume: Option<Decimal>,
156 num_trades: Option<u64>,
157 funding_rate: Option<Decimal>,
158 open_interest: Option<Decimal>,
159 oracle_price: Option<Decimal>,
160 debug: DebugLog,
161 ) -> Self {
162 let (local_timestamp, timestamp_str) = get_local_time();
163 let timestamp = timestamp.unwrap_or(local_timestamp);
164 Self {
165 timestamp,
166 timestamp_str,
167 price,
168 volume,
169 num_trades,
170 funding_rate,
171 open_interest,
172 oracle_price,
173 debug,
174 }
175 }
176}
177
/// A stored price sample tagged with a `name` and a `token_name`.
///
/// `TransactionLog::get_price_market_data` groups samples first by `name`, then by
/// `token_name`.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PriceLog {
    /// Document id; `None` in `PriceLog::default()`.
    pub id: Option<u32>,
    // presumably the market/exchange name — confirm with writers of this record
    pub name: String,
    pub token_name: String,
    pub price_point: PricePoint,
}
185
186impl HasId for PriceLog {
187 fn id(&self) -> Option<u32> {
188 self.id
189 }
190}
191
/// Candlestick pattern label.
///
/// `None` is the `Default` variant, presumably meaning "no pattern detected". Variants are
/// serialized by name. `CandlePattern::to_one_hot` maps them to slots 0..=15 in the order
/// listed here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum CandlePattern {
    #[default]
    None,
    Hammer,
    InvertedHammer,
    BullishEngulfing,
    BearishEngulfing,
    Doji,
    Marubozu,
    MorningStar,
    EveningStar,
    ThreeWhiteSoldiers,
    ThreeBlackCrows,
    PiercingPattern,
    DarkCloudCover,
    Harami,
    HaramiCross,
    SpinningTop,
}
212
213impl CandlePattern {
214 pub fn to_one_hot(&self) -> [Decimal; 16] {
215 let mut one_hot = [Decimal::ZERO; 16];
216
217 match self {
218 CandlePattern::None => one_hot[0] = Decimal::ONE,
219 CandlePattern::Hammer => one_hot[1] = Decimal::ONE,
220 CandlePattern::InvertedHammer => one_hot[2] = Decimal::ONE,
221 CandlePattern::BullishEngulfing => one_hot[3] = Decimal::ONE,
222 CandlePattern::BearishEngulfing => one_hot[4] = Decimal::ONE,
223 CandlePattern::Doji => one_hot[5] = Decimal::ONE,
224 CandlePattern::Marubozu => one_hot[6] = Decimal::ONE,
225 CandlePattern::MorningStar => one_hot[7] = Decimal::ONE,
226 CandlePattern::EveningStar => one_hot[8] = Decimal::ONE,
227 CandlePattern::ThreeWhiteSoldiers => one_hot[9] = Decimal::ONE,
228 CandlePattern::ThreeBlackCrows => one_hot[10] = Decimal::ONE,
229 CandlePattern::PiercingPattern => one_hot[11] = Decimal::ONE,
230 CandlePattern::DarkCloudCover => one_hot[12] = Decimal::ONE,
231 CandlePattern::Harami => one_hot[13] = Decimal::ONE,
232 CandlePattern::HaramiCross => one_hot[14] = Decimal::ONE,
233 CandlePattern::SpinningTop => one_hot[15] = Decimal::ONE,
234 }
235
236 one_hot
237 }
238}
239
/// Diagnostic snapshot attached to `PricePoint::debug` and `PositionLog::debug`.
///
/// The struct has 29 numeric inputs (`input_1`..`input_29`), 10 candle-pattern inputs
/// (`input_30`..`input_39`), two required outputs and three optional outputs.
/// NOTE(review): these are presumably model features and predictions. What each slot means
/// is not visible in this file, so confirm against the code that fills them.
/// Field names are serialized as-is, so renaming any of them breaks stored documents.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct DebugLog {
    // numeric inputs
    pub input_1: Decimal,
    pub input_2: Decimal,
    pub input_3: Decimal,
    pub input_4: Decimal,
    pub input_5: Decimal,
    pub input_6: Decimal,
    pub input_7: Decimal,
    pub input_8: Decimal,
    pub input_9: Decimal,
    pub input_10: Decimal,
    pub input_11: Decimal,
    pub input_12: Decimal,
    pub input_13: Decimal,
    pub input_14: Decimal,
    pub input_15: Decimal,
    pub input_16: Decimal,
    pub input_17: Decimal,
    pub input_18: Decimal,
    pub input_19: Decimal,
    pub input_20: Decimal,
    pub input_21: Decimal,
    pub input_22: Decimal,
    pub input_23: Decimal,
    pub input_24: Decimal,
    pub input_25: Decimal,
    pub input_26: Decimal,
    pub input_27: Decimal,
    pub input_28: Decimal,
    pub input_29: Decimal,
    // candle-pattern inputs
    pub input_30: CandlePattern,
    pub input_31: CandlePattern,
    pub input_32: CandlePattern,
    pub input_33: CandlePattern,
    pub input_34: CandlePattern,
    pub input_35: CandlePattern,
    pub input_36: CandlePattern,
    pub input_37: CandlePattern,
    pub input_38: CandlePattern,
    pub input_39: CandlePattern,
    // outputs
    pub output_1: Decimal,
    pub output_2: Decimal,
    pub output_3: Option<Decimal>,
    pub output_4: Option<Decimal>,
    pub output_5: Option<Decimal>,
}
287
/// A stored trading position.
///
/// `TransactionLog::get_all_positions` sorts these records by `open_timestamp`.
/// `state` and `position_type` are free-form strings; their allowed values are not visible
/// here.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PositionLog {
    /// Document id; `None` in `PositionLog::default()`.
    pub id: Option<u32>,
    pub fund_name: String,
    pub order_id: String,
    pub ordered_price: Decimal,
    pub state: String,
    pub token_name: String,
    pub open_time_str: String,
    /// Sort key used by `TransactionLog::get_all_positions`.
    pub open_timestamp: i64,
    pub close_time_str: String,
    pub average_open_price: Decimal,
    pub position_type: String,
    pub close_price: Decimal,
    pub asset_in_usd: Decimal,
    pub pnl: Decimal,
    pub fee: Decimal,
    /// Diagnostic values recorded alongside the position.
    pub debug: DebugLog,
}
307
/// Opaque model bytes persisted by `ModelParams`.
///
/// `ModelParams` bincode-encodes this whole struct before storing it in MongoDB or in a
/// `<key>.bin` file, and decodes it the same way on load.
#[derive(Serialize, Deserialize)]
pub struct SerializableModel {
    pub model: Vec<u8>,
}
312
313impl HasId for PositionLog {
314 fn id(&self) -> Option<u32> {
315 self.id
316 }
317}
318
/// MongoDB-backed store for position, price and PnL logs and for the `AppState` document.
///
/// Reads and writes can target different databases (`db_r_name` / `db_w_name`).
pub struct TransactionLog {
    /// Id counters for positions, prices and PnL entries, seeded in `TransactionLog::new`.
    counter: Counter,
    /// Database used by `get_r_db`.
    db_r_name: String,
    /// Database used by `get_w_db`.
    db_w_name: String,
    /// Shared client; database handles are looked up from it on each `get_db` call.
    client_holder: Arc<Mutex<ClientHolder>>,
}
325
326impl TransactionLog {
327 pub async fn new(
328 max_position_counter: Option<u32>,
329 max_price_counter: Option<u32>,
330 max_pnl_counter: Option<u32>,
331 mongodb_uri: &str,
332 db_r_name: &str,
333 db_w_name: &str,
334 back_test: bool,
335 ) -> Self {
336 let mut client_options = match ClientOptions::parse(mongodb_uri).await {
338 Ok(client_options) => client_options,
339 Err(e) => {
340 panic!("{:?}", e);
341 }
342 };
343 let tls_options = TlsOptions::builder().build();
344 client_options.tls = Some(Tls::Enabled(tls_options));
345 let client_holder = Arc::new(Mutex::new(ClientHolder::new(client_options)));
346
347 let db_w = shared_mongodb::database::get(&client_holder, &db_w_name)
349 .await
350 .unwrap();
351 let db_r = shared_mongodb::database::get(&client_holder, &db_r_name)
352 .await
353 .unwrap();
354
355 create_unique_index(&db_w)
357 .await
358 .expect("Error creating unique index in db_w");
359 create_unique_index(&db_r)
360 .await
361 .expect("Error creating unique index in db_r");
362
363 if back_test {
364 if let Err(e) = Self::delete_all_positions(&db_w).await {
365 panic!("delete_all_positions failed: {:?}", e);
366 }
367 if let Err(e) = Self::delete_app_state(&db_w).await {
368 panic!("delete_app_state failed: {:?}", e);
369 }
370 }
371
372 let last_position_counter =
373 TransactionLog::get_last_transaction_id(&db_w, CounterType::Position).await;
374 let last_price_counter =
375 TransactionLog::get_last_transaction_id(&db_w, CounterType::Price).await;
376 let last_pnl_counter =
377 TransactionLog::get_last_transaction_id(&db_w, CounterType::Pnl).await;
378
379 let counter = Counter::new(
380 max_position_counter,
381 max_price_counter,
382 max_pnl_counter,
383 last_position_counter,
384 last_price_counter,
385 last_pnl_counter,
386 );
387
388 log::warn!(
389 "position = {}/{:?}, price = {}/{:?}, pnl = {}/{:?}",
390 last_position_counter,
391 max_position_counter,
392 last_price_counter,
393 max_price_counter,
394 last_pnl_counter,
395 max_pnl_counter,
396 );
397
398 TransactionLog {
399 counter,
400 db_r_name: db_r_name.to_owned(),
401 db_w_name: db_w_name.to_owned(),
402 client_holder,
403 }
404 }
405
406 pub fn increment_counter(&self, counter_type: CounterType) -> u32 {
407 self.counter.increment(counter_type)
408 }
409
410 pub async fn get_last_transaction_id(db: &Database, counter_type: CounterType) -> u32 {
411 match counter_type {
412 CounterType::Position => get_last_id::<PositionLog>(db).await,
413 CounterType::Price => get_last_id::<PriceLog>(db).await,
414 CounterType::Pnl => get_last_id::<PnlLog>(db).await,
415 }
416 }
417
418 pub async fn get_w_db(&self) -> Option<Database> {
419 self.get_db(false).await
420 }
421
422 pub async fn get_r_db(&self) -> Option<Database> {
423 self.get_db(true).await
424 }
425
426 async fn get_db(&self, read: bool) -> Option<Database> {
427 let db_name = if read {
428 &self.db_r_name
429 } else {
430 &self.db_w_name
431 };
432 let db = match database::get(&self.client_holder, db_name).await {
433 Ok(db) => Some(db),
434 Err(e) => {
435 log::error!("get_db: {:?}", e);
436 None
437 }
438 };
439 db
440 }
441
442 pub async fn update_transaction(
443 db: &Database,
444 item: &PositionLog,
445 ) -> Result<(), Box<dyn error::Error>> {
446 update_item(db, item).await?;
447 Ok(())
448 }
449
450 pub async fn update_price(db: &Database, item: PriceLog) -> Result<(), Box<dyn error::Error>> {
451 update_item(db, &item).await?;
452 Ok(())
453 }
454
455 pub async fn copy_price(db_r: &Database, db_w: &Database, limit: Option<u32>) {
456 let item = PriceLog::default();
457 let items = {
458 match search_items(db_r, &item, SearchMode::Ascending, limit, None, Some("id")).await {
459 Ok(items) => items,
460 Err(e) => {
461 log::error!("get price: {:?}", e);
462 return;
463 }
464 }
465 };
466 log::debug!("get prices: num = {}", items.len());
467
468 for item in &items {
469 match insert_item(db_w, item).await {
470 Ok(_) => {}
471 Err(e) => {
472 log::error!("write price: {:?}", e);
473 return;
474 }
475 }
476 }
477 }
478
479 pub async fn copy_position(db_r: &Database, db_w: &Database, limit: Option<u32>) {
480 let item = PositionLog::default();
481 let items = {
482 match search_items(db_r, &item, SearchMode::Ascending, limit, None, Some("id")).await {
483 Ok(items) => items,
484 Err(e) => {
485 log::error!("get position: {:?}", e);
486 return;
487 }
488 }
489 };
490 log::debug!("get position: num = {}", items.len());
491
492 for item in &items {
493 match insert_item(db_w, item).await {
494 Ok(_) => {}
495 Err(e) => {
496 log::error!("write position: {:?}", e);
497 return;
498 }
499 }
500 }
501 }
502
503 pub async fn get_price_market_data(
504 db: &Database,
505 limit: Option<u32>,
506 id: Option<u32>,
507 is_ascend: bool,
508 ) -> HashMap<String, HashMap<String, Vec<PricePoint>>> {
509 let search_mode = if is_ascend {
510 SearchMode::Ascending
511 } else {
512 SearchMode::Descending
513 };
514 let sort_key = Some("price_point.timestamp");
515 let item = PriceLog::default();
516
517 let items = match id {
518 Some(id) => search_item(db, &item, Some(id), sort_key)
519 .await
520 .map(|item| vec![item]),
521 None => search_items(db, &item, search_mode, limit, None, sort_key).await,
522 };
523
524 let Ok(mut items) = items else {
525 log::warn!("get_price_market_data: search failed");
526 return HashMap::new();
527 };
528
529 items.sort_by_key(|p| p.price_point.timestamp);
530
531 let mut result = HashMap::new();
532 for price_log in items {
533 result
534 .entry(price_log.name)
535 .or_insert_with(HashMap::new)
536 .entry(price_log.token_name)
537 .or_insert_with(Vec::new)
538 .push(price_log.price_point);
539 }
540
541 result
542 }
543
544 pub async fn get_all_positions(
545 db: &Database,
546 limit: Option<u32>,
547 id: Option<u32>,
548 is_ascend: bool,
549 ) -> Vec<PositionLog> {
550 let search_mode = if is_ascend {
551 SearchMode::Ascending
552 } else {
553 SearchMode::Descending
554 };
555 let sort_key = Some("open_timestamp");
556 let item = PositionLog::default();
557
558 let items = if let Some(id) = id {
559 match search_item(db, &item, Some(id), sort_key).await {
560 Ok(position) => vec![position],
561 Err(e) => {
562 log::warn!("get_all_positions: {:?}", e);
563 vec![]
564 }
565 }
566 } else {
567 match search_items(db, &item, search_mode, limit, None, sort_key).await {
568 Ok(positions) => positions,
569 Err(e) => {
570 log::warn!("get_all_positions: {:?}", e);
571 vec![]
572 }
573 }
574 };
575
576 items
577 }
578
579 async fn delete_all_positions(db: &Database) -> Result<(), Box<dyn error::Error>> {
580 let item = PositionLog::default();
581 delete_item_all(db, &item).await
582 }
583
584 pub async fn insert_pnl(db: &Database, item: PnlLog) -> Result<(), Box<dyn error::Error>> {
585 insert_item(db, &item).await?;
586 Ok(())
587 }
588
589 pub async fn get_app_state(db: &Database) -> AppState {
590 let item = AppState::default();
591 match search_item(db, &item, Some(1), Some("id")).await {
592 Ok(item) => item,
593 Err(e) => {
594 log::warn!("get_app_state: {:?}", e);
595 item
596 }
597 }
598 }
599
600 async fn delete_app_state(db: &Database) -> Result<(), Box<dyn error::Error>> {
601 let item = AppState::default();
602 delete_item_all(db, &item).await
603 }
604
605 pub async fn update_app_state(
606 db: &Database,
607 last_execution_time: Option<SystemTime>,
608 last_equity: Option<Decimal>,
609 ave_dd: Option<Decimal>,
610 max_dd: Option<Decimal>,
611 cumulative_return: Option<Decimal>,
612 cumulative_dd: Option<Decimal>,
613 score: Option<Decimal>,
614 score_2: Option<Decimal>,
615 score_3: Option<Decimal>,
616 curcuit_break: bool,
617 error_time: Option<String>,
618 max_invested_amount: Option<Decimal>,
619 fund_configs: Option<Vec<FundConfig>>,
620 ) -> Result<(), Box<dyn error::Error>> {
621 let item = AppState::default();
622 let mut item = match search_item(db, &item, Some(1), Some("id")).await {
623 Ok(prev_item) => prev_item,
624 Err(_) => item,
625 };
626
627 if last_execution_time.is_some() {
628 item.last_execution_time = last_execution_time;
629 }
630
631 if let Some(last_equity) = last_equity {
632 item.last_equity = Some(last_equity.round());
633 }
634
635 if let Some(ave_dd) = ave_dd {
636 item.ave_dd = Some(ave_dd.round());
637 }
638
639 if let Some(max_dd_val) = max_dd {
640 if item
641 .max_dd
642 .map_or(true, |item_max_dd| max_dd_val > item_max_dd)
643 {
644 item.max_dd = Some(max_dd_val.round());
645 }
646 }
647
648 if let Some(cumulative_return) = cumulative_return {
649 item.cumulative_return += cumulative_return.round();
650 }
651
652 if let Some(cumulative_dd) = cumulative_dd {
653 item.cumulative_dd += cumulative_dd.round();
654 }
655
656 if score.is_some() {
657 item.score = score;
658 }
659
660 if score_2.is_some() {
661 item.score_2 = score_2;
662 }
663
664 if score_3.is_some() {
665 item.score_3 = score_3;
666 }
667
668 item.curcuit_break = curcuit_break;
669
670 if let Some(error_time) = error_time {
671 item.error_time.push(error_time);
672 }
673
674 if let Some(max_invested_amount) = max_invested_amount {
675 item.max_invested_amount = max_invested_amount.round();
676 }
677
678 if let Some(fund_configs) = fund_configs {
679 item.fund_configs = Some(fund_configs);
680 }
681
682 update_item(db, &item).await?;
683 Ok(())
684 }
685
686 pub fn db_w_name(&self) -> &str {
687 &self.db_w_name
688 }
689}
690
/// Stores and loads `SerializableModel`s, either in MongoDB or as local `.bin` files.
#[derive(Clone)]
pub struct ModelParams {
    /// Database holding the model collection.
    db_name: String,
    /// Shared client; the database handle is looked up from it on each DB operation.
    client_holder: Arc<Mutex<ClientHolder>>,
    /// Collection name; `ModelParams::new` sets it to `"model_params"`.
    collection_name: String,
    /// `true` selects the MongoDB backend, `false` the file backend.
    save_to_db: bool,
    /// Directory for `<key>.bin` files; `None` means the current working directory.
    file_path: Option<String>,
}
699
700impl ModelParams {
701 pub async fn new(
702 mongodb_uri: &str,
703 db_name: &str,
704 save_to_db: bool,
705 file_path: Option<String>,
706 ) -> Self {
707 let mut client_options = match ClientOptions::parse(mongodb_uri).await {
709 Ok(client_options) => client_options,
710 Err(e) => {
711 panic!("{:?}", e);
712 }
713 };
714 let tls_options = TlsOptions::builder().build();
715 client_options.tls = Some(Tls::Enabled(tls_options));
716 let client_holder = Arc::new(Mutex::new(ClientHolder::new(client_options)));
717
718 ModelParams {
719 db_name: db_name.to_owned(),
720 client_holder,
721 collection_name: "model_params".to_owned(),
722 save_to_db,
723 file_path,
724 }
725 }
726
727 async fn get_db(&self) -> Option<Database> {
728 let db = match database::get(&self.client_holder, &self.db_name).await {
729 Ok(db) => Some(db),
730 Err(e) => {
731 log::error!("get_db: {:?}", e);
732 None
733 }
734 };
735 db
736 }
737
738 pub async fn save_model(
739 &self,
740 key: &str,
741 model: &SerializableModel,
742 ) -> Result<(), Box<dyn std::error::Error>> {
743 if self.save_to_db {
744 self.save_model_to_db(key, model).await
745 } else {
746 self.save_model_to_file(key, model).await
747 }
748 }
749
750 pub async fn load_model(
751 &self,
752 key: &str,
753 ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
754 if self.save_to_db {
755 self.load_model_from_db(key).await
756 } else {
757 self.load_model_from_file(key).await
758 }
759 }
760
761 async fn save_model_to_db(
762 &self,
763 key: &str,
764 model: &SerializableModel,
765 ) -> Result<(), Box<dyn std::error::Error>> {
766 let db = self.get_db().await.ok_or("no db")?;
767 let collection: Collection<Document> = db.collection(&self.collection_name);
768 let serialized_model = bincode::serialize(model)?;
769
770 let document = doc! {
771 "key": key,
772 "model": Bson::Binary(mongodb::bson::Binary {
773 subtype: mongodb::bson::spec::BinarySubtype::Generic,
774 bytes: serialized_model
775 })
776 };
777
778 collection
779 .update_one(
780 doc! { "key": key },
781 doc! { "$set": document },
782 mongodb::options::UpdateOptions::builder()
783 .upsert(true)
784 .build(),
785 )
786 .await?;
787 Ok(())
788 }
789
790 async fn load_model_from_db(
791 &self,
792 key: &str,
793 ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
794 let db = self.get_db().await.ok_or("no db")?;
795 let collection: Collection<Document> = db.collection(&self.collection_name);
796
797 let filter = doc! { "key": key };
798 let document = collection
799 .find_one(filter, None)
800 .await?
801 .ok_or("No model found in the collection")?;
802
803 if let Some(Bson::Binary(model_bytes)) = document.get("model") {
804 let model: SerializableModel = bincode::deserialize(&model_bytes.bytes)?;
805 Ok(model)
806 } else {
807 Err("Invalid data format".into())
808 }
809 }
810
811 async fn save_model_to_file(
812 &self,
813 key: &str,
814 model: &SerializableModel,
815 ) -> Result<(), Box<dyn std::error::Error>> {
816 let serialized_model = bincode::serialize(model)?;
817 let file_name = format!("{}.bin", key);
818
819 let file_path = if let Some(ref dir) = self.file_path {
820 Path::new(dir).join(file_name)
821 } else {
822 Path::new(&file_name).to_path_buf()
823 };
824
825 let mut file = File::create(&file_path)?;
826 file.write_all(&serialized_model)?;
827 Ok(())
828 }
829
830 async fn load_model_from_file(
831 &self,
832 key: &str,
833 ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
834 let file_name = format!("{}.bin", key);
835
836 let file_path = if let Some(ref dir) = self.file_path {
837 Path::new(dir).join(file_name)
838 } else {
839 Path::new(&file_name).to_path_buf()
840 };
841
842 let mut file = File::open(&file_path).map_err(|e| {
843 log::error!("Failed to open file: {:?}: {}", file_path, e);
844 e
845 })?;
846 let mut buffer = Vec::new();
847 file.read_to_end(&mut buffer)?;
848 let model: SerializableModel = bincode::deserialize(&buffer)?;
849 Ok(model)
850 }
851}