1use bson::doc;
4use bson::Bson;
5use bson::Document;
6use debot_utils::get_local_time;
7use debot_utils::HasId;
8use mongodb::Collection;
9use mongodb::{
10 options::{ClientOptions, Tls, TlsOptions},
11 Database,
12};
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15use shared_mongodb::{database, ClientHolder};
16use std::collections::HashMap;
17use std::error;
18use std::fs::File;
19use std::io::{Read, Write};
20use std::path::Path;
21use std::sync::Arc;
22use std::time::SystemTime;
23use tokio::sync::Mutex;
24
25use crate::delete_item_all;
26use crate::SearchMode;
27use crate::TradingStrategy;
28use crate::{
29 create_unique_index, insert_item, search_item, search_items, update_item, Counter, CounterType,
30 Entity,
31};
32
/// Serde fallback value for `DebugLog::input_40`.
///
/// Referenced by name from the `#[serde(default = "default_input_40")]`
/// attribute on that field, so a serialized `DebugLog` that lacks the field
/// deserializes with `Decimal::ONE`.
fn default_input_40() -> Decimal {
    Decimal::ONE
}
36
37async fn get_last_id<T: Default + Entity + HasId>(db: &Database) -> u32 {
38 let item = T::default();
39 match search_items(
40 db,
41 &item,
42 crate::SearchMode::Descending,
43 Some(1),
44 None,
45 Some("id"),
46 )
47 .await
48 {
49 Ok(mut items) => items.pop().and_then(|item| item.id()).unwrap_or(0),
50 Err(e) => {
51 log::info!("get_last_id: {:?}", e);
52 0
53 }
54 }
55}
56
/// Term selector stored in `FundConfig::atr_term`.
///
/// Each variant has a numeric code exposed by `SampleTerm::to_numeric`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub enum SampleTerm {
    /// Numeric code `1`.
    TradingTerm,
    /// Numeric code `2`.
    ShortTerm,
    /// Numeric code `3`.
    LongTerm,
}
63
64impl SampleTerm {
65 pub fn to_numeric(&self) -> Decimal {
66 match self {
67 SampleTerm::TradingTerm => Decimal::new(1, 0),
68 SampleTerm::ShortTerm => Decimal::new(2, 0),
69 SampleTerm::LongTerm => Decimal::new(3, 0),
70 }
71 }
72}
73
/// Configuration record for a single fund.
///
/// Instances are stored in `AppState::fund_configs`. The fields are plain
/// data; nothing in this file interprets them. The per-field notes below are
/// inferred from names and types only — confirm against the consuming code.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct FundConfig {
    /// Token identifier (presumably the traded token's name — TODO confirm).
    pub token: String,
    /// Strategy variant, as defined by the crate-level `TradingStrategy`.
    pub trading_strategy: TradingStrategy,
    pub balance_per_strategy: Decimal,
    pub risk_reward: Decimal,
    /// Optional; `None` presumably means "not configured" — TODO confirm.
    pub take_profit_ratio: Option<Decimal>,
    pub atr_spread: Decimal,
    /// Term selector; see `SampleTerm`.
    pub atr_term: SampleTerm,
    /// Presumably in seconds, going by the `_sec` suffix — TODO confirm.
    pub entry_timeout_sec: i64,
    /// Presumably in seconds, going by the `_sec` suffix — TODO confirm.
    pub max_holding_sec: i64,
    pub order_size_multiplier: Decimal,
    pub tick_spread: i64,
    pub bias_ticks: i64,
}
89
/// Persisted application state.
///
/// The code in this file always looks the state up with id `1`
/// (`get_app_state`, `update_app_state`), and `Default` also uses id `1`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct AppState {
    /// Record id; `1` in the default value.
    pub id: u32,
    /// Replaced by `update_app_state` only when a new value is supplied.
    pub last_execution_time: Option<SystemTime>,
    /// Stored rounded (`Decimal::round`) by `update_app_state`.
    pub last_equity: Option<Decimal>,
    /// Stored rounded by `update_app_state`.
    pub ave_dd: Option<Decimal>,
    /// Only replaced by `update_app_state` when the supplied value is greater
    /// than the stored one (or nothing is stored yet); stored rounded.
    pub max_dd: Option<Decimal>,
    /// Running total: `update_app_state` adds the rounded supplied value.
    pub cumulative_return: Decimal,
    /// Running total: `update_app_state` adds the rounded supplied value.
    pub cumulative_dd: Decimal,
    pub score: Option<Decimal>,
    pub score_2: Option<Decimal>,
    pub score_3: Option<Decimal>,
    /// Overwritten on every `update_app_state` call.
    /// NOTE(review): the spelling ("curcuit") is part of the serialized field
    /// name; renaming it would need a serde rename or a data migration.
    pub curcuit_break: bool,
    /// Append-only list: `update_app_state` pushes each supplied entry.
    pub error_time: Vec<String>,
    /// Stored rounded by `update_app_state`.
    pub max_invested_amount: Decimal,
    /// Replaced wholesale by `update_app_state` when a new list is supplied.
    pub fund_configs: Option<Vec<FundConfig>>,
}
107
108impl Default for AppState {
109 fn default() -> Self {
110 Self {
111 id: 1,
112 last_execution_time: None,
113 last_equity: None,
114 ave_dd: None,
115 max_dd: None,
116 cumulative_return: Decimal::ZERO,
117 cumulative_dd: Decimal::ZERO,
118 score: None,
119 score_2: None,
120 score_3: None,
121 curcuit_break: false,
122 error_time: vec![],
123 max_invested_amount: Decimal::ZERO,
124 fund_configs: Some(vec![]),
125 }
126 }
127}
128
/// One profit-and-loss record, inserted through `TransactionLog::insert_pnl`.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PnlLog {
    /// Record id; `None` in the default value.
    pub id: Option<u32>,
    /// Date as text (format is not established in this file).
    pub date: String,
    pub pnl: Decimal,
}
135
impl HasId for PnlLog {
    /// Returns the record's `id` field unchanged.
    fn id(&self) -> Option<u32> {
        self.id
    }
}
141
/// A single price observation with optional market metadata.
///
/// Built by `PricePoint::new`; grouped per name and token by
/// `TransactionLog::get_price_market_data`, which orders points by
/// `timestamp`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct PricePoint {
    /// Numeric timestamp (unit is defined by `get_local_time` / the caller —
    /// TODO confirm).
    pub timestamp: i64,
    /// Textual timestamp as returned by `get_local_time`.
    pub timestamp_str: String,
    pub price: Decimal,
    pub volume: Option<Decimal>,
    pub num_trades: Option<u64>,
    pub funding_rate: Option<Decimal>,
    pub open_interest: Option<Decimal>,
    pub oracle_price: Option<Decimal>,
    /// Optional debug payload attached to this point.
    pub debug: Option<DebugLog>,
}
154
155impl PricePoint {
156 pub fn new(
157 price: Decimal,
158 timestamp: Option<i64>,
159 volume: Option<Decimal>,
160 num_trades: Option<u64>,
161 funding_rate: Option<Decimal>,
162 open_interest: Option<Decimal>,
163 oracle_price: Option<Decimal>,
164 debug: Option<DebugLog>,
165 ) -> Self {
166 let (local_timestamp, timestamp_str) = get_local_time();
167 let timestamp = timestamp.unwrap_or(local_timestamp);
168 Self {
169 timestamp,
170 timestamp_str,
171 price,
172 volume,
173 num_trades,
174 funding_rate,
175 open_interest,
176 oracle_price,
177 debug,
178 }
179 }
180}
181
/// Stored price record: a `PricePoint` tagged with a name and a token name.
///
/// `TransactionLog::get_price_market_data` groups points first by `name`,
/// then by `token_name`.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PriceLog {
    /// Record id; `None` in the default value.
    pub id: Option<u32>,
    /// Outer grouping key.
    pub name: String,
    /// Inner grouping key.
    pub token_name: String,
    pub price_point: PricePoint,
}
189
impl HasId for PriceLog {
    /// Returns the record's `id` field unchanged.
    fn id(&self) -> Option<u32> {
        self.id
    }
}
195
/// Candlestick pattern label.
///
/// There are 16 variants; `CandlePattern::to_one_hot` assigns each one a slot
/// in declaration order (`None` = 0 ... `SpinningTop` = 15). `None` is the
/// default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum CandlePattern {
    #[default]
    None,
    Hammer,
    InvertedHammer,
    BullishEngulfing,
    BearishEngulfing,
    Doji,
    Marubozu,
    MorningStar,
    EveningStar,
    ThreeWhiteSoldiers,
    ThreeBlackCrows,
    PiercingPattern,
    DarkCloudCover,
    Harami,
    HaramiCross,
    SpinningTop,
}
216
217impl CandlePattern {
218 pub fn to_one_hot(&self) -> [Decimal; 16] {
219 let mut one_hot = [Decimal::ZERO; 16];
220
221 match self {
222 CandlePattern::None => one_hot[0] = Decimal::ONE,
223 CandlePattern::Hammer => one_hot[1] = Decimal::ONE,
224 CandlePattern::InvertedHammer => one_hot[2] = Decimal::ONE,
225 CandlePattern::BullishEngulfing => one_hot[3] = Decimal::ONE,
226 CandlePattern::BearishEngulfing => one_hot[4] = Decimal::ONE,
227 CandlePattern::Doji => one_hot[5] = Decimal::ONE,
228 CandlePattern::Marubozu => one_hot[6] = Decimal::ONE,
229 CandlePattern::MorningStar => one_hot[7] = Decimal::ONE,
230 CandlePattern::EveningStar => one_hot[8] = Decimal::ONE,
231 CandlePattern::ThreeWhiteSoldiers => one_hot[9] = Decimal::ONE,
232 CandlePattern::ThreeBlackCrows => one_hot[10] = Decimal::ONE,
233 CandlePattern::PiercingPattern => one_hot[11] = Decimal::ONE,
234 CandlePattern::DarkCloudCover => one_hot[12] = Decimal::ONE,
235 CandlePattern::Harami => one_hot[13] = Decimal::ONE,
236 CandlePattern::HaramiCross => one_hot[14] = Decimal::ONE,
237 CandlePattern::SpinningTop => one_hot[15] = Decimal::ONE,
238 }
239
240 one_hot
241 }
242}
243
/// Flat debug record of numbered inputs and outputs.
///
/// `input_1`..`input_29` and `input_40` are decimals, `input_30`..`input_39`
/// are candle patterns, and `output_1`..`output_5` are decimals. The meaning
/// of each slot is not established in this file.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct DebugLog {
    pub input_1: Decimal,
    pub input_2: Decimal,
    pub input_3: Decimal,
    pub input_4: Decimal,
    pub input_5: Decimal,
    pub input_6: Decimal,
    pub input_7: Decimal,
    pub input_8: Decimal,
    pub input_9: Decimal,
    pub input_10: Decimal,
    pub input_11: Decimal,
    pub input_12: Decimal,
    pub input_13: Decimal,
    pub input_14: Decimal,
    pub input_15: Decimal,
    pub input_16: Decimal,
    pub input_17: Decimal,
    pub input_18: Decimal,
    pub input_19: Decimal,
    pub input_20: Decimal,
    pub input_21: Decimal,
    pub input_22: Decimal,
    pub input_23: Decimal,
    pub input_24: Decimal,
    pub input_25: Decimal,
    pub input_26: Decimal,
    pub input_27: Decimal,
    pub input_28: Decimal,
    pub input_29: Decimal,
    pub input_30: CandlePattern,
    pub input_31: CandlePattern,
    pub input_32: CandlePattern,
    pub input_33: CandlePattern,
    pub input_34: CandlePattern,
    pub input_35: CandlePattern,
    pub input_36: CandlePattern,
    pub input_37: CandlePattern,
    pub input_38: CandlePattern,
    pub input_39: CandlePattern,
    // When the field is absent from serialized data, serde fills it with
    // `default_input_40()` (`Decimal::ONE`).
    // NOTE(review): the derived `Default` impl uses `Decimal::default()`
    // instead, so the two defaults may differ — confirm this is intended.
    #[serde(default = "default_input_40")]
    pub input_40: Decimal,
    pub output_1: Decimal,
    pub output_2: Decimal,
    pub output_3: Decimal,
    pub output_4: Decimal,
    pub output_5: Decimal,
}
293
/// Stored record of one position, including its debug snapshot.
///
/// Written through `TransactionLog::update_transaction` and read back by
/// `TransactionLog::get_all_positions`, which passes `"open_timestamp"` as
/// the sort key.
#[derive(Serialize, Deserialize, Clone, Debug, Default)]
pub struct PositionLog {
    /// Record id; `None` in the default value.
    pub id: Option<u32>,
    pub fund_name: String,
    pub order_id: String,
    pub ordered_price: Decimal,
    /// State as free-form text (the set of values is not defined here).
    pub state: String,
    pub token_name: String,
    pub open_time_str: String,
    /// Numeric open time; the key handed to the position searches.
    pub open_timestamp: i64,
    pub close_time_str: String,
    pub average_open_price: Decimal,
    /// Position type as free-form text (the set of values is not defined here).
    pub position_type: String,
    pub close_price: Decimal,
    pub asset_in_usd: Decimal,
    pub pnl: Decimal,
    pub fee: Decimal,
    /// Debug snapshot stored with the position.
    pub debug: DebugLog,
}
313
/// Opaque model payload.
///
/// `ModelParams` serializes and deserializes this wrapper with `bincode` when
/// saving to / loading from MongoDB or a file.
#[derive(Serialize, Deserialize)]
pub struct SerializableModel {
    /// Raw model bytes; their format is not interpreted in this file.
    pub model: Vec<u8>,
}
318
impl HasId for PositionLog {
    /// Returns the record's `id` field unchanged.
    fn id(&self) -> Option<u32> {
        self.id
    }
}
324
/// Handle for the position / price / PnL logs kept in MongoDB.
///
/// Holds the id counters plus the names of a read database and a write
/// database that share one client holder.
pub struct TransactionLog {
    /// Id counters, seeded in `new` from the last ids found in the write database.
    counter: Counter,
    /// Name of the database returned by `get_r_db`.
    db_r_name: String,
    /// Name of the database returned by `get_w_db`.
    db_w_name: String,
    /// Shared MongoDB client holder used to obtain `Database` handles.
    client_holder: Arc<Mutex<ClientHolder>>,
}
331
332impl TransactionLog {
333 pub async fn new(
334 max_position_counter: Option<u32>,
335 max_price_counter: Option<u32>,
336 max_pnl_counter: Option<u32>,
337 mongodb_uri: &str,
338 db_r_name: &str,
339 db_w_name: &str,
340 back_test: bool,
341 ) -> Self {
342 let mut client_options = match ClientOptions::parse(mongodb_uri).await {
344 Ok(client_options) => client_options,
345 Err(e) => {
346 panic!("{:?}", e);
347 }
348 };
349 let tls_options = TlsOptions::builder().build();
350 client_options.tls = Some(Tls::Enabled(tls_options));
351 let client_holder = Arc::new(Mutex::new(ClientHolder::new(client_options)));
352
353 let db_w = shared_mongodb::database::get(&client_holder, &db_w_name)
355 .await
356 .unwrap();
357 let db_r = shared_mongodb::database::get(&client_holder, &db_r_name)
358 .await
359 .unwrap();
360
361 create_unique_index(&db_w)
363 .await
364 .expect("Error creating unique index in db_w");
365 create_unique_index(&db_r)
366 .await
367 .expect("Error creating unique index in db_r");
368
369 if back_test {
370 if let Err(e) = Self::delete_all_positions(&db_w).await {
371 panic!("delete_all_positions failed: {:?}", e);
372 }
373 if let Err(e) = Self::delete_app_state(&db_w).await {
374 panic!("delete_app_state failed: {:?}", e);
375 }
376 }
377
378 let last_position_counter =
379 TransactionLog::get_last_transaction_id(&db_w, CounterType::Position).await;
380 let last_price_counter =
381 TransactionLog::get_last_transaction_id(&db_w, CounterType::Price).await;
382 let last_pnl_counter =
383 TransactionLog::get_last_transaction_id(&db_w, CounterType::Pnl).await;
384
385 let counter = Counter::new(
386 max_position_counter,
387 max_price_counter,
388 max_pnl_counter,
389 last_position_counter,
390 last_price_counter,
391 last_pnl_counter,
392 );
393
394 log::warn!(
395 "position = {}/{:?}, price = {}/{:?}, pnl = {}/{:?}",
396 last_position_counter,
397 max_position_counter,
398 last_price_counter,
399 max_price_counter,
400 last_pnl_counter,
401 max_pnl_counter,
402 );
403
404 TransactionLog {
405 counter,
406 db_r_name: db_r_name.to_owned(),
407 db_w_name: db_w_name.to_owned(),
408 client_holder,
409 }
410 }
411
    /// Delegates to `Counter::increment` for `counter_type` and returns its result.
    pub fn increment_counter(&self, counter_type: CounterType) -> u32 {
        self.counter.increment(counter_type)
    }
415
    /// Returns the last stored id for the log type selected by `counter_type`.
    ///
    /// Dispatches to `get_last_id` with `PositionLog`, `PriceLog` or `PnlLog`;
    /// that helper yields `0` when nothing can be read.
    pub async fn get_last_transaction_id(db: &Database, counter_type: CounterType) -> u32 {
        match counter_type {
            CounterType::Position => get_last_id::<PositionLog>(db).await,
            CounterType::Price => get_last_id::<PriceLog>(db).await,
            CounterType::Pnl => get_last_id::<PnlLog>(db).await,
        }
    }
423
    /// Returns a handle to the write database, or `None` if it cannot be obtained.
    pub async fn get_w_db(&self) -> Option<Database> {
        self.get_db(false).await
    }
427
    /// Returns a handle to the read database, or `None` if it cannot be obtained.
    pub async fn get_r_db(&self) -> Option<Database> {
        self.get_db(true).await
    }
431
432 async fn get_db(&self, read: bool) -> Option<Database> {
433 let db_name = if read {
434 &self.db_r_name
435 } else {
436 &self.db_w_name
437 };
438 let db = match database::get(&self.client_holder, db_name).await {
439 Ok(db) => Some(db),
440 Err(e) => {
441 log::error!("get_db: {:?}", e);
442 None
443 }
444 };
445 db
446 }
447
    /// Passes the position log `item` to `update_item` on `db`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `update_item`.
    pub async fn update_transaction(
        db: &Database,
        item: &PositionLog,
    ) -> Result<(), Box<dyn error::Error>> {
        update_item(db, item).await?;
        Ok(())
    }
455
    /// Passes the price log `item` to `update_item` on `db`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `update_item`.
    pub async fn update_price(db: &Database, item: PriceLog) -> Result<(), Box<dyn error::Error>> {
        update_item(db, &item).await?;
        Ok(())
    }
460
461 pub async fn copy_price(db_r: &Database, db_w: &Database, limit: Option<u32>) {
462 let item = PriceLog::default();
463 let items = {
464 match search_items(db_r, &item, SearchMode::Ascending, limit, None, Some("id")).await {
465 Ok(items) => items,
466 Err(e) => {
467 log::error!("get price: {:?}", e);
468 return;
469 }
470 }
471 };
472 log::debug!("get prices: num = {}", items.len());
473
474 for item in &items {
475 match insert_item(db_w, item).await {
476 Ok(_) => {}
477 Err(e) => {
478 log::error!("write price: {:?}", e);
479 return;
480 }
481 }
482 }
483 }
484
485 pub async fn copy_position(db_r: &Database, db_w: &Database, limit: Option<u32>) {
486 let item = PositionLog::default();
487 let items = {
488 match search_items(db_r, &item, SearchMode::Ascending, limit, None, Some("id")).await {
489 Ok(items) => items,
490 Err(e) => {
491 log::error!("get position: {:?}", e);
492 return;
493 }
494 }
495 };
496 log::debug!("get position: num = {}", items.len());
497
498 for item in &items {
499 match insert_item(db_w, item).await {
500 Ok(_) => {}
501 Err(e) => {
502 log::error!("write position: {:?}", e);
503 return;
504 }
505 }
506 }
507 }
508
509 pub async fn get_price_market_data(
510 db: &Database,
511 limit: Option<u32>,
512 id: Option<u32>,
513 is_ascend: bool,
514 ) -> HashMap<String, HashMap<String, Vec<PricePoint>>> {
515 let search_mode = if is_ascend {
516 SearchMode::Ascending
517 } else {
518 SearchMode::Descending
519 };
520 let sort_key = Some("price_point.timestamp");
521 let item = PriceLog::default();
522
523 let items = match id {
524 Some(id) => search_item(db, &item, Some(id), sort_key)
525 .await
526 .map(|item| vec![item]),
527 None => search_items(db, &item, search_mode, limit, None, sort_key).await,
528 };
529
530 let Ok(mut items) = items else {
531 log::warn!("get_price_market_data: search failed");
532 return HashMap::new();
533 };
534
535 items.sort_by_key(|p| p.price_point.timestamp);
536
537 let mut result = HashMap::new();
538 for price_log in items {
539 result
540 .entry(price_log.name)
541 .or_insert_with(HashMap::new)
542 .entry(price_log.token_name)
543 .or_insert_with(Vec::new)
544 .push(price_log.price_point);
545 }
546
547 result
548 }
549
550 pub async fn get_all_positions(
551 db: &Database,
552 limit: Option<u32>,
553 id: Option<u32>,
554 is_ascend: bool,
555 ) -> Vec<PositionLog> {
556 let search_mode = if is_ascend {
557 SearchMode::Ascending
558 } else {
559 SearchMode::Descending
560 };
561 let sort_key = Some("open_timestamp");
562 let item = PositionLog::default();
563
564 let items = if let Some(id) = id {
565 match search_item(db, &item, Some(id), sort_key).await {
566 Ok(position) => vec![position],
567 Err(e) => {
568 log::warn!("get_all_positions: {:?}", e);
569 vec![]
570 }
571 }
572 } else {
573 match search_items(db, &item, search_mode, limit, None, sort_key).await {
574 Ok(positions) => positions,
575 Err(e) => {
576 log::warn!("get_all_positions: {:?}", e);
577 vec![]
578 }
579 }
580 };
581
582 items
583 }
584
    /// Calls `delete_item_all` on `db` with a default `PositionLog` as the
    /// type marker; used by `new` when running a back test.
    ///
    /// # Errors
    ///
    /// Returns whatever error `delete_item_all` reports.
    async fn delete_all_positions(db: &Database) -> Result<(), Box<dyn error::Error>> {
        let item = PositionLog::default();
        delete_item_all(db, &item).await
    }
589
    /// Inserts the PnL record `item` into `db` via `insert_item`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `insert_item`.
    pub async fn insert_pnl(db: &Database, item: PnlLog) -> Result<(), Box<dyn error::Error>> {
        insert_item(db, &item).await?;
        Ok(())
    }
594
595 pub async fn get_app_state(db: &Database) -> AppState {
596 let item = AppState::default();
597 match search_item(db, &item, Some(1), Some("id")).await {
598 Ok(item) => item,
599 Err(e) => {
600 log::warn!("get_app_state: {:?}", e);
601 item
602 }
603 }
604 }
605
    /// Calls `delete_item_all` on `db` with a default `AppState` as the type
    /// marker; used by `new` when running a back test.
    ///
    /// # Errors
    ///
    /// Returns whatever error `delete_item_all` reports.
    async fn delete_app_state(db: &Database) -> Result<(), Box<dyn error::Error>> {
        let item = AppState::default();
        delete_item_all(db, &item).await
    }
610
611 pub async fn update_app_state(
612 db: &Database,
613 last_execution_time: Option<SystemTime>,
614 last_equity: Option<Decimal>,
615 ave_dd: Option<Decimal>,
616 max_dd: Option<Decimal>,
617 cumulative_return: Option<Decimal>,
618 cumulative_dd: Option<Decimal>,
619 score: Option<Decimal>,
620 score_2: Option<Decimal>,
621 score_3: Option<Decimal>,
622 curcuit_break: bool,
623 error_time: Option<String>,
624 max_invested_amount: Option<Decimal>,
625 fund_configs: Option<Vec<FundConfig>>,
626 ) -> Result<(), Box<dyn error::Error>> {
627 let item = AppState::default();
628 let mut item = match search_item(db, &item, Some(1), Some("id")).await {
629 Ok(prev_item) => prev_item,
630 Err(_) => item,
631 };
632
633 if last_execution_time.is_some() {
634 item.last_execution_time = last_execution_time;
635 }
636
637 if let Some(last_equity) = last_equity {
638 item.last_equity = Some(last_equity.round());
639 }
640
641 if let Some(ave_dd) = ave_dd {
642 item.ave_dd = Some(ave_dd.round());
643 }
644
645 if let Some(max_dd_val) = max_dd {
646 if item
647 .max_dd
648 .map_or(true, |item_max_dd| max_dd_val > item_max_dd)
649 {
650 item.max_dd = Some(max_dd_val.round());
651 }
652 }
653
654 if let Some(cumulative_return) = cumulative_return {
655 item.cumulative_return += cumulative_return.round();
656 }
657
658 if let Some(cumulative_dd) = cumulative_dd {
659 item.cumulative_dd += cumulative_dd.round();
660 }
661
662 if score.is_some() {
663 item.score = score;
664 }
665
666 if score_2.is_some() {
667 item.score_2 = score_2;
668 }
669
670 if score_3.is_some() {
671 item.score_3 = score_3;
672 }
673
674 item.curcuit_break = curcuit_break;
675
676 if let Some(error_time) = error_time {
677 item.error_time.push(error_time);
678 }
679
680 if let Some(max_invested_amount) = max_invested_amount {
681 item.max_invested_amount = max_invested_amount.round();
682 }
683
684 if let Some(fund_configs) = fund_configs {
685 item.fund_configs = Some(fund_configs);
686 }
687
688 update_item(db, &item).await?;
689 Ok(())
690 }
691
    /// Name of the write database this log was created with.
    pub fn db_w_name(&self) -> &str {
        &self.db_w_name
    }
695}
696
/// Storage handle for serialized models, backed by either MongoDB or files.
#[derive(Clone)]
pub struct ModelParams {
    /// Database used when `save_to_db` is `true`.
    db_name: String,
    /// Shared MongoDB client holder used to obtain `Database` handles.
    client_holder: Arc<Mutex<ClientHolder>>,
    /// Collection holding the models; set to `"model_params"` by `new`.
    collection_name: String,
    /// `true` selects MongoDB, `false` selects files, for both save and load.
    save_to_db: bool,
    /// Optional directory for `<key>.bin` files; without it the bare file
    /// name is used as the path.
    file_path: Option<String>,
}
705
706impl ModelParams {
707 pub async fn new(
708 mongodb_uri: &str,
709 db_name: &str,
710 save_to_db: bool,
711 file_path: Option<String>,
712 ) -> Self {
713 let mut client_options = match ClientOptions::parse(mongodb_uri).await {
715 Ok(client_options) => client_options,
716 Err(e) => {
717 panic!("{:?}", e);
718 }
719 };
720 let tls_options = TlsOptions::builder().build();
721 client_options.tls = Some(Tls::Enabled(tls_options));
722 let client_holder = Arc::new(Mutex::new(ClientHolder::new(client_options)));
723
724 ModelParams {
725 db_name: db_name.to_owned(),
726 client_holder,
727 collection_name: "model_params".to_owned(),
728 save_to_db,
729 file_path,
730 }
731 }
732
733 async fn get_db(&self) -> Option<Database> {
734 let db = match database::get(&self.client_holder, &self.db_name).await {
735 Ok(db) => Some(db),
736 Err(e) => {
737 log::error!("get_db: {:?}", e);
738 None
739 }
740 };
741 db
742 }
743
    /// Stores `model` under `key`, in MongoDB when `save_to_db` is set and in
    /// a file otherwise.
    ///
    /// # Errors
    ///
    /// Returns the error of the selected backend (`save_model_to_db` or
    /// `save_model_to_file`).
    pub async fn save_model(
        &self,
        key: &str,
        model: &SerializableModel,
    ) -> Result<(), Box<dyn std::error::Error>> {
        if self.save_to_db {
            self.save_model_to_db(key, model).await
        } else {
            self.save_model_to_file(key, model).await
        }
    }
755
    /// Loads the model stored under `key`, from MongoDB when `save_to_db` is
    /// set and from a file otherwise.
    ///
    /// # Errors
    ///
    /// Returns the error of the selected backend (`load_model_from_db` or
    /// `load_model_from_file`).
    pub async fn load_model(
        &self,
        key: &str,
    ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
        if self.save_to_db {
            self.load_model_from_db(key).await
        } else {
            self.load_model_from_file(key).await
        }
    }
766
767 async fn save_model_to_db(
768 &self,
769 key: &str,
770 model: &SerializableModel,
771 ) -> Result<(), Box<dyn std::error::Error>> {
772 let db = self.get_db().await.ok_or("no db")?;
773 let collection: Collection<Document> = db.collection(&self.collection_name);
774 let serialized_model = bincode::serialize(model)?;
775
776 let document = doc! {
777 "key": key,
778 "model": Bson::Binary(mongodb::bson::Binary {
779 subtype: mongodb::bson::spec::BinarySubtype::Generic,
780 bytes: serialized_model
781 })
782 };
783
784 collection
785 .update_one(
786 doc! { "key": key },
787 doc! { "$set": document },
788 mongodb::options::UpdateOptions::builder()
789 .upsert(true)
790 .build(),
791 )
792 .await?;
793 Ok(())
794 }
795
796 async fn load_model_from_db(
797 &self,
798 key: &str,
799 ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
800 let db = self.get_db().await.ok_or("no db")?;
801 let collection: Collection<Document> = db.collection(&self.collection_name);
802
803 let filter = doc! { "key": key };
804 let document = collection
805 .find_one(filter, None)
806 .await?
807 .ok_or("No model found in the collection")?;
808
809 if let Some(Bson::Binary(model_bytes)) = document.get("model") {
810 let model: SerializableModel = bincode::deserialize(&model_bytes.bytes)?;
811 Ok(model)
812 } else {
813 Err("Invalid data format".into())
814 }
815 }
816
817 async fn save_model_to_file(
818 &self,
819 key: &str,
820 model: &SerializableModel,
821 ) -> Result<(), Box<dyn std::error::Error>> {
822 let serialized_model = bincode::serialize(model)?;
823 let file_name = format!("{}.bin", key);
824
825 let file_path = if let Some(ref dir) = self.file_path {
826 Path::new(dir).join(file_name)
827 } else {
828 Path::new(&file_name).to_path_buf()
829 };
830
831 let mut file = File::create(&file_path)?;
832 file.write_all(&serialized_model)?;
833 Ok(())
834 }
835
836 async fn load_model_from_file(
837 &self,
838 key: &str,
839 ) -> Result<SerializableModel, Box<dyn std::error::Error>> {
840 let file_name = format!("{}.bin", key);
841
842 let file_path = if let Some(ref dir) = self.file_path {
843 Path::new(dir).join(file_name)
844 } else {
845 Path::new(&file_name).to_path_buf()
846 };
847
848 let mut file = File::open(&file_path).map_err(|e| {
849 log::error!("Failed to open file: {:?}: {}", file_path, e);
850 e
851 })?;
852 let mut buffer = Vec::new();
853 file.read_to_end(&mut buffer)?;
854 let model: SerializableModel = bincode::deserialize(&buffer)?;
855 Ok(model)
856 }
857}