flame_rs/client/
mod.rs

/*
Copyright 2023 The Flame Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
13
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex};
16
17use chrono::{DateTime, Duration, TimeZone, Utc};
18use futures::TryFutureExt;
19// use serde::{Deserialize, Serialize};
20use serde_derive::{Deserialize, Serialize};
21use stdng::{logs::TraceFn, trace_fn};
22use tokio_stream::StreamExt;
23use tonic::transport::Channel;
24use tonic::transport::Endpoint;
25use tonic::Request;
26
27use self::rpc::frontend_client::FrontendClient as FlameFrontendClient;
28use self::rpc::{
29    ApplicationSpec, CloseSessionRequest, CreateSessionRequest, CreateTaskRequest, Environment,
30    GetApplicationRequest, GetSessionRequest, GetTaskRequest, ListApplicationRequest,
31    ListExecutorRequest, ListSessionRequest, ListTaskRequest, RegisterApplicationRequest,
32    SessionSpec, TaskSpec, UnregisterApplicationRequest, UpdateApplicationRequest,
33    WatchTaskRequest,
34};
35use crate::apis::flame as rpc;
36use crate::apis::Shim;
37use crate::apis::{
38    ApplicationID, ApplicationState, CommonData, ExecutorState, FlameError, SessionID,
39    SessionState, TaskID, TaskInput, TaskOutput, TaskState,
40};
41use crate::lock_ptr;
42
/// The generated frontend gRPC client, specialised to a tonic [`Channel`] transport.
type FlameClient = FlameFrontendClient<Channel>;
44
45pub async fn connect(addr: &str) -> Result<Connection, FlameError> {
46    let endpoint = Endpoint::from_shared(addr.to_string())
47        .map_err(|_| FlameError::InvalidConfig("invalid address".to_string()))?;
48
49    let channel = endpoint
50        .connect()
51        .await
52        .map_err(|_| FlameError::InvalidConfig("failed to connect".to_string()))?;
53
54    Ok(Connection { channel })
55}
56
/// An event attached to a session or a task, converted from `rpc::Event`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Event {
    /// Event code copied verbatim from the RPC event.
    pub code: i32,
    /// Optional message copied verbatim from the RPC event.
    pub message: Option<String>,
    /// Creation time of the event; (de)serialized through the `serde_utc` module.
    #[serde(with = "serde_utc")]
    pub creation_time: DateTime<Utc>,
}
64
/// A connection to the Flame frontend, as returned by [`connect`].
#[derive(Clone)]
pub struct Connection {
    /// Transport channel; it is cloned to build a fresh client for each call.
    pub(crate) channel: Channel,
}
69
/// Parameters for creating a session; consumed by `Connection::create_session`.
#[derive(Clone, Serialize, Deserialize)]
pub struct SessionAttributes {
    /// Copied into the `application` field of the session spec
    /// (presumably the name of a registered application — confirm against the RPC schema).
    pub application: String,
    /// Slot count copied into the session spec.
    pub slots: u32,
    /// Optional data copied into the session spec; (de)serialized through `serde_message`.
    #[serde(with = "serde_message")]
    pub common_data: Option<CommonData>,
}
77
/// Optional schema strings of an application; mirrors `rpc::ApplicationSchema`
/// field for field.
///
/// NOTE(review): the format of the schema strings is not visible in this file —
/// confirm against the RPC definition before relying on it.
#[derive(Clone, Serialize, Deserialize)]
pub struct ApplicationSchema {
    /// Schema string for the `input` slot, if any.
    pub input: Option<String>,
    /// Schema string for the `output` slot, if any.
    pub output: Option<String>,
    /// Schema string for the `common_data` slot, if any.
    pub common_data: Option<String>,
}
84
/// Client-side description of an application; converts to and from
/// `rpc::ApplicationSpec`.
#[derive(Clone, Serialize, Deserialize)]
pub struct ApplicationAttributes {
    /// Shim selection; converted to and from the RPC shim value.
    pub shim: Shim,

    /// Optional image, passed through to the spec unchanged.
    pub image: Option<String>,
    /// Optional description, passed through to the spec unchanged.
    pub description: Option<String>,
    /// Labels, passed through to the spec unchanged.
    pub labels: Vec<String>,
    /// Optional command, passed through to the spec unchanged.
    pub command: Option<String>,
    /// Arguments, passed through to the spec unchanged.
    pub arguments: Vec<String>,
    /// Environment variables; sent as a list of `Environment { name, value }` entries.
    pub environments: HashMap<String, String>,
    /// Optional working directory, passed through to the spec unchanged.
    pub working_directory: Option<String>,
    /// Optional instance limit, passed through to the spec unchanged.
    pub max_instances: Option<u32>,
    /// Optional release delay; sent to the frontend as whole seconds and
    /// (de)serialized through the `serde_duration` module.
    #[serde(with = "serde_duration")]
    pub delay_release: Option<Duration>,
    /// Optional schema of the application.
    pub schema: Option<ApplicationSchema>,
}
101
/// An application as reported by the frontend; converted from `rpc::Application`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Application {
    /// Name taken from the RPC metadata.
    pub name: ApplicationID,

    /// Attributes converted from the RPC spec.
    pub attributes: ApplicationAttributes,

    /// State converted from the RPC status.
    pub state: ApplicationState,
    /// Creation time taken from the RPC status; (de)serialized through `serde_utc`.
    #[serde(with = "serde_utc")]
    pub creation_time: DateTime<Utc>,
}
112
/// An executor as reported by the frontend; converted from `rpc::Executor`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Executor {
    /// Identifier taken from the RPC metadata.
    pub id: String,
    /// State converted from the RPC status.
    pub state: ExecutorState,
    /// Optional session id taken from the RPC status
    /// (presumably the session the executor is bound to — confirm).
    pub session_id: Option<String>,
    /// Slot count taken from the RPC spec.
    pub slots: u32,
    /// Node taken from the RPC spec.
    pub node: String,
}
121
/// A session as reported by the frontend, optionally bound to a client so that
/// task operations can be issued through it.
#[derive(Clone, Serialize, Deserialize)]
pub struct Session {
    /// Client used by the task methods. It is never (de)serialized, and it is
    /// `None` for sessions produced by the plain `From<&rpc::Session>`
    /// conversion; `Connection::create_session` and `Connection::get_session`
    /// set it before returning.
    #[serde(skip)]
    pub(crate) client: Option<FlameClient>,

    /// Identifier taken from the RPC metadata.
    pub id: SessionID,
    /// Slot count taken from the RPC spec.
    pub slots: u32,
    /// Application taken from the RPC spec.
    pub application: String,
    /// Creation time taken from the RPC status; (de)serialized through `serde_utc`.
    #[serde(with = "serde_utc")]
    pub creation_time: DateTime<Utc>,

    /// State converted from the RPC status.
    pub state: SessionState,
    /// `pending` counter copied from the RPC status (presumably a task count — confirm).
    pub pending: i32,
    /// `running` counter copied from the RPC status (presumably a task count — confirm).
    pub running: i32,
    /// `succeed` counter copied from the RPC status (presumably a task count — confirm).
    pub succeed: i32,
    /// `failed` counter copied from the RPC status (presumably a task count — confirm).
    pub failed: i32,

    /// Events converted from the RPC status.
    pub events: Vec<Event>,
    /// Optional task list; the conversion from `rpc::Session` always leaves it `None`.
    pub tasks: Option<Vec<Task>>,
}
142
/// A task as reported by the frontend; converted from `rpc::Task`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Task {
    /// Identifier taken from the RPC metadata.
    pub id: TaskID,
    /// Identifier of the owning session, taken from the RPC spec.
    pub ssn_id: SessionID,

    /// State converted from the RPC status.
    pub state: TaskState,

    /// Optional input taken from the RPC spec; (de)serialized through `serde_message`.
    #[serde(with = "serde_message")]
    pub input: Option<TaskInput>,
    /// Optional output taken from the RPC spec; (de)serialized through `serde_message`.
    #[serde(with = "serde_message")]
    pub output: Option<TaskOutput>,

    /// Events converted from the RPC status.
    pub events: Vec<Event>,
}
157
/// Shared, mutex-guarded handle to a [`TaskInformer`], as accepted by
/// `Session::run_task` and `Session::watch_task`.
pub type TaskInformerPtr = Arc<Mutex<dyn TaskInformer>>;
159
/// Callbacks invoked by `Session::watch_task` while it consumes the watch stream.
pub trait TaskInformer: Send + Sync + 'static {
    /// Called with the converted task for every successful item of the watch stream.
    fn on_update(&mut self, task: Task);
    /// Called with the converted error for every failed item of the watch stream.
    fn on_error(&mut self, e: FlameError);
}
164
165impl Task {
166    pub fn is_completed(&self) -> bool {
167        self.state == TaskState::Succeed || self.state == TaskState::Failed
168    }
169
170    pub fn is_succeed(&self) -> bool {
171        self.state == TaskState::Succeed
172    }
173
174    pub fn is_failed(&self) -> bool {
175        self.state == TaskState::Failed
176    }
177}
178
179impl Connection {
180    pub async fn create_session(&self, attrs: &SessionAttributes) -> Result<Session, FlameError> {
181        trace_fn!("Connection::create_session");
182
183        let create_ssn_req = CreateSessionRequest {
184            session: Some(SessionSpec {
185                application: attrs.application.clone(),
186                slots: attrs.slots,
187                common_data: attrs.common_data.clone().map(CommonData::into),
188            }),
189        };
190
191        let mut client = FlameClient::new(self.channel.clone());
192        let ssn = client.create_session(create_ssn_req).await?;
193        let ssn = ssn.into_inner();
194
195        let mut ssn = Session::from(&ssn);
196        ssn.client = Some(client);
197
198        Ok(ssn)
199    }
200
201    pub async fn list_session(&self) -> Result<Vec<Session>, FlameError> {
202        let mut client = FlameClient::new(self.channel.clone());
203        let ssn_list = client.list_session(ListSessionRequest {}).await?;
204
205        Ok(ssn_list
206            .into_inner()
207            .sessions
208            .iter()
209            .map(Session::from)
210            .collect())
211    }
212
213    pub async fn get_session(&self, id: &SessionID) -> Result<Session, FlameError> {
214        let mut client = FlameClient::new(self.channel.clone());
215        let ssn = client
216            .get_session(GetSessionRequest {
217                session_id: id.to_string(),
218            })
219            .await?;
220
221        let ssn = ssn.into_inner();
222        let mut ssn = Session::from(&ssn);
223        ssn.client = Some(client);
224
225        Ok(ssn)
226    }
227
228    pub async fn register_application(
229        &self,
230        name: String,
231        app: ApplicationAttributes,
232    ) -> Result<(), FlameError> {
233        let mut client = FlameClient::new(self.channel.clone());
234
235        let req = RegisterApplicationRequest {
236            name,
237            application: Some(ApplicationSpec::from(app)),
238        };
239
240        let res = client
241            .register_application(Request::new(req))
242            .await?
243            .into_inner();
244
245        if res.return_code < 0 {
246            Err(FlameError::Network(res.message.unwrap_or_default()))
247        } else {
248            Ok(())
249        }
250    }
251
252    pub async fn update_application(
253        &self,
254        name: String,
255        app: ApplicationAttributes,
256    ) -> Result<(), FlameError> {
257        let mut client = FlameClient::new(self.channel.clone());
258
259        let req = UpdateApplicationRequest {
260            name,
261            application: Some(ApplicationSpec::from(app)),
262        };
263
264        let res = client
265            .update_application(Request::new(req))
266            .await?
267            .into_inner();
268
269        if res.return_code < 0 {
270            Err(FlameError::Network(res.message.unwrap_or_default()))
271        } else {
272            Ok(())
273        }
274    }
275
276    pub async fn unregister_application(&self, name: String) -> Result<(), FlameError> {
277        let mut client = FlameClient::new(self.channel.clone());
278
279        let req = UnregisterApplicationRequest { name };
280
281        let res = client
282            .unregister_application(Request::new(req))
283            .await?
284            .into_inner();
285
286        if res.return_code < 0 {
287            Err(FlameError::Network(res.message.unwrap_or_default()))
288        } else {
289            Ok(())
290        }
291    }
292
293    pub async fn list_application(&self) -> Result<Vec<Application>, FlameError> {
294        let mut client = FlameClient::new(self.channel.clone());
295        let app_list = client.list_application(ListApplicationRequest {}).await?;
296
297        Ok(app_list
298            .into_inner()
299            .applications
300            .iter()
301            .map(Application::from)
302            .collect())
303    }
304
305    pub async fn get_application(&self, name: &str) -> Result<Application, FlameError> {
306        let mut client = FlameClient::new(self.channel.clone());
307        let app = client
308            .get_application(GetApplicationRequest {
309                name: name.to_string(),
310            })
311            .await?;
312        Ok(Application::from(&app.into_inner()))
313    }
314
315    pub async fn list_executor(&self) -> Result<Vec<Executor>, FlameError> {
316        let mut client = FlameClient::new(self.channel.clone());
317        let executor_list = client.list_executor(ListExecutorRequest {}).await?;
318        Ok(executor_list
319            .into_inner()
320            .executors
321            .iter()
322            .map(Executor::from)
323            .collect())
324    }
325}
326
327impl Session {
328    pub async fn create_task(&self, input: Option<TaskInput>) -> Result<Task, FlameError> {
329        trace_fn!("Session::create_task");
330        let mut client = self
331            .client
332            .clone()
333            .ok_or(FlameError::Internal("no flame client".to_string()))?;
334
335        let create_task_req = CreateTaskRequest {
336            task: Some(TaskSpec {
337                session_id: self.id.clone(),
338                input: input.map(|input| input.to_vec()),
339                output: None,
340            }),
341        };
342
343        let task = client.create_task(create_task_req).await?;
344
345        let task = task.into_inner();
346        Ok(Task::from(&task))
347    }
348
349    pub async fn get_task(&self, id: &TaskID) -> Result<Task, FlameError> {
350        trace_fn!("Session::get_task");
351        let mut client = self
352            .client
353            .clone()
354            .ok_or(FlameError::Internal("no flame client".to_string()))?;
355
356        let get_task_req = GetTaskRequest {
357            session_id: self.id.clone(),
358            task_id: id.clone(),
359        };
360        let task = client.get_task(get_task_req).await?;
361
362        let task = task.into_inner();
363        Ok(Task::from(&task))
364    }
365
366    pub async fn list_tasks(&self) -> Result<Vec<Task>, FlameError> {
367        // TODO (k82cn): Add top n tasks to avoid memory overflow.
368        trace_fn!("Session::list_task");
369        let mut client = self
370            .client
371            .clone()
372            .ok_or(FlameError::Internal("no flame client".to_string()))?;
373        let task_stream = client
374            .list_task(Request::new(ListTaskRequest {
375                session_id: self.id.to_string(),
376            }))
377            .await?;
378
379        let mut task_list = vec![];
380
381        let mut task_stream = task_stream.into_inner();
382        while let Some(task) = task_stream.next().await {
383            if let Ok(t) = task {
384                task_list.push(Task::from(&t));
385            }
386        }
387
388        Ok(task_list)
389    }
390
391    pub async fn run_task(
392        &self,
393        input: Option<TaskInput>,
394        informer_ptr: TaskInformerPtr,
395    ) -> Result<(), FlameError> {
396        trace_fn!("Session::run_task");
397        self.create_task(input)
398            .and_then(|task| self.watch_task(task.ssn_id.clone(), task.id, informer_ptr))
399            .await
400    }
401
402    pub async fn watch_task(
403        &self,
404        session_id: SessionID,
405        task_id: TaskID,
406        informer_ptr: TaskInformerPtr,
407    ) -> Result<(), FlameError> {
408        trace_fn!("Session::watch_task");
409        let mut client = self
410            .client
411            .clone()
412            .ok_or(FlameError::Internal("no flame client".to_string()))?;
413
414        let watch_task_req = WatchTaskRequest {
415            session_id,
416            task_id,
417        };
418        let mut task_stream = client.watch_task(watch_task_req).await?.into_inner();
419        while let Some(task) = task_stream.next().await {
420            match task {
421                Ok(t) => {
422                    let mut informer = lock_ptr!(informer_ptr)?;
423                    informer.on_update(Task::from(&t));
424                }
425                Err(e) => {
426                    let mut informer = lock_ptr!(informer_ptr)?;
427                    informer.on_error(FlameError::from(e.clone()));
428                }
429            }
430        }
431        Ok(())
432    }
433
434    pub async fn close(&self) -> Result<(), FlameError> {
435        trace_fn!("Session::close");
436        let mut client = self
437            .client
438            .clone()
439            .ok_or(FlameError::Internal("no flame client".to_string()))?;
440
441        let close_ssn_req = CloseSessionRequest {
442            session_id: self.id.clone(),
443        };
444
445        client.close_session(close_ssn_req).await?;
446
447        Ok(())
448    }
449}
450
451impl From<&rpc::Task> for Task {
452    fn from(task: &rpc::Task) -> Self {
453        let metadata = task.metadata.clone().unwrap();
454        let spec = task.spec.clone().unwrap();
455        let status = task.status.clone().unwrap();
456        Task {
457            id: metadata.id,
458            ssn_id: spec.session_id.clone(),
459            input: spec.input.map(TaskInput::from),
460            output: spec.output.map(TaskOutput::from),
461            state: TaskState::try_from(status.state).unwrap_or(TaskState::default()),
462            events: status.events.clone().into_iter().map(Event::from).collect(),
463        }
464    }
465}
466
467impl From<&rpc::Session> for Session {
468    fn from(ssn: &rpc::Session) -> Self {
469        let metadata = ssn.metadata.clone().unwrap();
470        let status = ssn.status.clone().unwrap();
471        let spec = ssn.spec.clone().unwrap();
472
473        let naivedatetime_utc =
474            DateTime::from_timestamp_millis(status.creation_time * 1000).unwrap();
475        let creation_time = Utc.from_utc_datetime(&naivedatetime_utc.naive_utc());
476
477        Session {
478            client: None,
479            id: metadata.id,
480            slots: spec.slots,
481            application: spec.application,
482            creation_time,
483            state: SessionState::try_from(status.state).unwrap_or(SessionState::default()),
484            pending: status.pending,
485            running: status.running,
486            succeed: status.succeed,
487            failed: status.failed,
488            events: status.events.clone().into_iter().map(Event::from).collect(),
489            tasks: None,
490        }
491    }
492}
493
494impl From<&rpc::Event> for Event {
495    fn from(event: &rpc::Event) -> Self {
496        let second = event.creation_time / 1000;
497        let nanosecond = ((event.creation_time % 1000) * 1_000_000) as u32;
498
499        Self {
500            code: event.code,
501            message: event.message.clone(),
502            creation_time: DateTime::from_timestamp(second, nanosecond).unwrap(),
503        }
504    }
505}
506
507impl From<rpc::Event> for Event {
508    fn from(event: rpc::Event) -> Self {
509        Event::from(&event)
510    }
511}
512
513impl From<&rpc::Application> for Application {
514    fn from(app: &rpc::Application) -> Self {
515        let metadata = app.metadata.clone().unwrap();
516        let spec = app.spec.clone().unwrap();
517        let status = app.status.unwrap();
518
519        let naivedatetime_utc =
520            DateTime::from_timestamp_millis(status.creation_time * 1000).unwrap();
521        let creation_time = Utc.from_utc_datetime(&naivedatetime_utc.naive_utc());
522
523        Self {
524            name: metadata.name,
525            attributes: ApplicationAttributes::from(spec),
526            state: ApplicationState::from(status.state()),
527            creation_time,
528        }
529    }
530}
531
532impl From<ApplicationAttributes> for ApplicationSpec {
533    fn from(app: ApplicationAttributes) -> Self {
534        Self {
535            shim: app.shim.into(),
536            image: app.image.clone(),
537            description: app.description.clone(),
538            labels: app.labels.clone(),
539            command: app.command.clone(),
540            arguments: app.arguments.clone(),
541            environments: app
542                .environments
543                .clone()
544                .into_iter()
545                .map(|(key, value)| Environment { name: key, value })
546                .collect(),
547            working_directory: app.working_directory.clone(),
548            max_instances: app.max_instances,
549            delay_release: app.delay_release.map(|s| s.num_seconds()),
550            schema: app.schema.clone().map(rpc::ApplicationSchema::from),
551        }
552    }
553}
554
555impl From<ApplicationSpec> for ApplicationAttributes {
556    fn from(app: ApplicationSpec) -> Self {
557        Self {
558            shim: app.shim().into(),
559            image: app.image.clone(),
560            description: app.description.clone(),
561            labels: app.labels.clone(),
562            command: app.command.clone(),
563            arguments: app.arguments.clone(),
564            environments: app
565                .environments
566                .clone()
567                .into_iter()
568                .map(|env| (env.name, env.value))
569                .collect(),
570            working_directory: app.working_directory.clone(),
571            max_instances: app.max_instances,
572            delay_release: app.delay_release.map(Duration::seconds),
573            schema: app.schema.clone().map(ApplicationSchema::from),
574        }
575    }
576}
577
578impl From<ApplicationSchema> for rpc::ApplicationSchema {
579    fn from(schema: ApplicationSchema) -> Self {
580        Self {
581            input: schema.input,
582            output: schema.output,
583            common_data: schema.common_data,
584        }
585    }
586}
587
588impl From<rpc::ApplicationSchema> for ApplicationSchema {
589    fn from(schema: rpc::ApplicationSchema) -> Self {
590        Self {
591            input: schema.input,
592            output: schema.output,
593            common_data: schema.common_data,
594        }
595    }
596}
597
598impl From<&rpc::Executor> for Executor {
599    fn from(e: &rpc::Executor) -> Self {
600        let spec = e.spec.clone().unwrap();
601        let status = e.status.clone().unwrap();
602        let metadata = e.metadata.clone().unwrap();
603
604        let state = rpc::ExecutorState::try_from(status.state).unwrap().into();
605
606        Executor {
607            id: metadata.id,
608            session_id: status.session_id,
609            slots: spec.slots,
610            node: spec.node,
611            state,
612        }
613    }
614}
615
616impl From<rpc::Executor> for Executor {
617    fn from(e: rpc::Executor) -> Self {
618        Executor::from(&e)
619    }
620}
621
622mod serde_duration {
623    use chrono::Duration;
624    use serde::{Deserialize, Deserializer, Serializer};
625
626    pub fn serialize<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
627    where
628        S: Serializer,
629    {
630        match duration {
631            Some(duration) => serializer.serialize_i64(duration.num_seconds()),
632            None => serializer.serialize_none(),
633        }
634    }
635
636    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
637    where
638        D: Deserializer<'de>,
639    {
640        let seconds = i64::deserialize(deserializer)?;
641        Ok(Some(Duration::seconds(seconds)))
642    }
643}
644
645mod serde_utc {
646    use chrono::{DateTime, Utc};
647    use serde::{self, Deserialize, Deserializer, Serializer};
648
649    pub fn serialize<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
650    where
651        S: Serializer,
652    {
653        serializer.serialize_i64(date.timestamp())
654    }
655
656    pub fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
657    where
658        D: Deserializer<'de>,
659    {
660        let timestamp = i64::deserialize(deserializer)?;
661        DateTime::<Utc>::from_timestamp(timestamp, 0)
662            .ok_or(serde::de::Error::custom("invalid timestamp"))
663    }
664}
665
666mod serde_message {
667    use bytes::Bytes;
668    use serde::{Deserialize, Deserializer, Serializer};
669
670    pub fn serialize<S>(message: &Option<Bytes>, serializer: S) -> Result<S::Ok, S::Error>
671    where
672        S: Serializer,
673    {
674        match message {
675            Some(message) => serializer.serialize_bytes(message),
676            None => serializer.serialize_none(),
677        }
678    }
679
680    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Bytes>, D::Error>
681    where
682        D: Deserializer<'de>,
683    {
684        let bytes = Vec::<u8>::deserialize(deserializer)?;
685        Ok(Some(Bytes::from(bytes)))
686    }
687}