heat_sdk/command/
train.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use burn::{config::Config, module::Module, tensor::backend::Backend};

use crate::client::HeatClient;

#[derive(Debug, Clone)]
pub struct MultiDevice<B: Backend>(pub Vec<B::Device>);

/// Everything a training command handler may draw its arguments from.
///
/// Bundles the Heat client, the devices to run on, and the serialized configuration string.
/// Handler parameters are produced from this value through `FromTrainCommandContext`.
#[derive(Clone, Debug)]
pub struct TrainCommandContext<B: Backend> {
    /// Client handle; handlers asking for a `HeatClient` receive a clone of it.
    client: HeatClient,
    /// Devices made available to the command.
    devices: Vec<<B as Backend>::Device>,
    /// Serialized config; parsed with `Config::load_binary` when a handler requests a config type.
    config: String,
}

impl<B: Backend> TrainCommandContext<B> {
    pub fn new(client: HeatClient, devices: Vec<B::Device>, config: String) -> Self {
        Self {
            client,
            devices,
            config,
        }
    }

    pub fn client(&mut self) -> &mut HeatClient {
        &mut self.client
    }

    pub fn devices(&mut self) -> &mut Vec<B::Device> {
        &mut self.devices
    }

    pub fn config(&self) -> &str {
        &self.config
    }
}

/// Produces one training-handler argument from a [`TrainCommandContext`].
///
/// Every parameter of a function used as a `TrainCommandHandler` must implement this trait.
/// This file provides implementations for `HeatClient`, `MultiDevice<B>` and every type
/// implementing burn's `Config`.
///
/// The trait is public because it appears as a bound on the public `TrainCommandHandler`
/// impls; keeping it private hid from downstream users which argument types are accepted
/// and made the bound impossible to name in their own generic code.
pub trait FromTrainCommandContext<B: Backend> {
    /// Builds `Self` from `context`.
    ///
    /// The context is only borrowed, so several arguments can be extracted from the same one.
    fn from_context(context: &TrainCommandContext<B>) -> Self;
}

impl<B: Backend> FromTrainCommandContext<B> for HeatClient {
    fn from_context(context: &TrainCommandContext<B>) -> Self {
        context.client.clone()
    }
}

/// A handler parameter of type `MultiDevice<B>` receives a copy of every device in the context.
impl<B: Backend> FromTrainCommandContext<B> for MultiDevice<B> {
    fn from_context(context: &TrainCommandContext<B>) -> Self {
        let devices = context.devices.to_vec();
        Self(devices)
    }
}

impl<B: Backend> IntoIterator for MultiDevice<B> {
    type Item = B::Device;
    type IntoIter = std::vec::IntoIter<Self::Item>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

/// Any burn `Config` type can be a handler parameter; it is deserialized from the context's
/// config string each time it is extracted.
///
/// # Panics
///
/// Panics if `Config::load_binary` fails on the stored config string.
impl<B: Backend, T: Config> FromTrainCommandContext<B> for T {
    fn from_context(context: &TrainCommandContext<B>) -> Self {
        let raw = context.config().as_bytes();
        T::load_binary(raw).expect("Config should be loaded")
    }
}

/// A callable that runs a training command given a [`TrainCommandContext`].
///
/// `T` is a marker type describing the handler's parameter list (a tuple of the parameter
/// types), which lets a separate blanket impl exist for each arity. `M` is the module produced
/// on success and `E` the error type, which must convert into a boxed `std::error::Error`.
pub trait TrainCommandHandler<B, T, M, E>
where
    B: Backend,
    M: Module<B>,
    E: Into<Box<dyn std::error::Error>>,
{
    /// Consumes the handler, builds its arguments from `context`, and invokes it.
    fn call(self, context: TrainCommandContext<B>) -> Result<M, E>;
}

/// Handlers that take no arguments; the context is ignored.
///
/// The bound is `FnOnce` rather than `Fn`: `call` takes the handler by value and invokes it
/// exactly once, so requiring `Fn` needlessly rejected closures that move captured state out.
/// Every `Fn` is also `FnOnce`, so previously accepted handlers still qualify.
impl<F, M, B, E: Into<Box<dyn std::error::Error>>> TrainCommandHandler<B, (), M, E> for F
where
    F: FnOnce() -> Result<M, E>,
    M: Module<B>,
    B: Backend,
{
    fn call(self, _context: TrainCommandContext<B>) -> Result<M, E> {
        (self)()
    }
}

/// Handlers that take one argument extracted from the context.
///
/// The bound is `FnOnce` rather than `Fn`: `call` takes the handler by value and invokes it
/// exactly once, so requiring `Fn` needlessly rejected closures that move captured state out.
/// Every `Fn` is also `FnOnce`, so previously accepted handlers still qualify.
impl<F, T, M, B, E: Into<Box<dyn std::error::Error>>> TrainCommandHandler<B, (T,), M, E> for F
where
    F: FnOnce(T) -> Result<M, E>,
    T: FromTrainCommandContext<B>,
    M: Module<B>,
    B: Backend,
{
    fn call(self, context: TrainCommandContext<B>) -> Result<M, E> {
        (self)(T::from_context(&context))
    }
}

/// Handlers that take two arguments extracted from the context, built in parameter order.
///
/// The bound is `FnOnce` rather than `Fn`: `call` takes the handler by value and invokes it
/// exactly once, so requiring `Fn` needlessly rejected closures that move captured state out.
/// Every `Fn` is also `FnOnce`, so previously accepted handlers still qualify.
impl<F, T1, T2, M, B, E: Into<Box<dyn std::error::Error>>> TrainCommandHandler<B, (T1, T2), M, E>
    for F
where
    F: FnOnce(T1, T2) -> Result<M, E>,
    T1: FromTrainCommandContext<B>,
    T2: FromTrainCommandContext<B>,
    M: Module<B>,
    B: Backend,
{
    fn call(self, context: TrainCommandContext<B>) -> Result<M, E> {
        (self)(T1::from_context(&context), T2::from_context(&context))
    }
}

/// Handlers that take three arguments extracted from the context, built in parameter order.
///
/// The bound is `FnOnce` rather than `Fn`: `call` takes the handler by value and invokes it
/// exactly once, so requiring `Fn` needlessly rejected closures that move captured state out.
/// Every `Fn` is also `FnOnce`, so previously accepted handlers still qualify.
impl<F, T1, T2, T3, M, B, E: Into<Box<dyn std::error::Error>>>
    TrainCommandHandler<B, (T1, T2, T3), M, E> for F
where
    F: FnOnce(T1, T2, T3) -> Result<M, E>,
    T1: FromTrainCommandContext<B>,
    T2: FromTrainCommandContext<B>,
    T3: FromTrainCommandContext<B>,
    M: Module<B>,
    B: Backend,
{
    fn call(self, context: TrainCommandContext<B>) -> Result<M, E> {
        (self)(
            T1::from_context(&context),
            T2::from_context(&context),
            T3::from_context(&context),
        )
    }
}