1use std::collections::HashMap;
15use std::sync::{Arc, Mutex};
16
17use chrono::{DateTime, Duration, TimeZone, Utc};
18use futures::TryFutureExt;
19use serde_derive::{Deserialize, Serialize};
21use stdng::{logs::TraceFn, trace_fn};
22use tokio_stream::StreamExt;
23use tonic::transport::Channel;
24use tonic::transport::Endpoint;
25use tonic::Request;
26
27use self::rpc::frontend_client::FrontendClient as FlameFrontendClient;
28use self::rpc::{
29 ApplicationSpec, CloseSessionRequest, CreateSessionRequest, CreateTaskRequest, Environment,
30 GetApplicationRequest, GetSessionRequest, GetTaskRequest, ListApplicationRequest,
31 ListExecutorRequest, ListSessionRequest, ListTaskRequest, RegisterApplicationRequest,
32 SessionSpec, TaskSpec, UnregisterApplicationRequest, UpdateApplicationRequest,
33 WatchTaskRequest,
34};
35use crate::apis::flame as rpc;
36use crate::apis::Shim;
37use crate::apis::{
38 ApplicationID, ApplicationState, CommonData, ExecutorState, FlameError, SessionID,
39 SessionState, TaskID, TaskInput, TaskOutput, TaskState,
40};
41use crate::lock_ptr;
42
/// gRPC client for the Flame frontend service, running over a tonic [`Channel`].
type FlameClient = FlameFrontendClient<Channel>;
44
45pub async fn connect(addr: &str) -> Result<Connection, FlameError> {
46 let endpoint = Endpoint::from_shared(addr.to_string())
47 .map_err(|_| FlameError::InvalidConfig("invalid address".to_string()))?;
48
49 let channel = endpoint
50 .connect()
51 .await
52 .map_err(|_| FlameError::InvalidConfig("failed to connect".to_string()))?;
53
54 Ok(Connection { channel })
55}
56
/// An event record attached to a [`Session`] or [`Task`] (both hold a `Vec<Event>`).
///
/// `creation_time` is serialized by `serde_utc` as whole seconds since the Unix epoch, so
/// sub-second precision is lost on a serde round trip.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Event {
    // NOTE(review): meaning of `code` is defined by the server; not interpreted here.
    pub code: i32,
    pub message: Option<String>,
    #[serde(with = "serde_utc")]
    pub creation_time: DateTime<Utc>,
}
64
/// Handle to a Flame frontend, created by [`connect`].
///
/// It only holds the tonic [`Channel`]; every request method clones that channel into a fresh
/// frontend client.
#[derive(Clone)]
pub struct Connection {
    pub(crate) channel: Channel,
}
69
/// Parameters for `Connection::create_session`; copied into the RPC `SessionSpec`.
#[derive(Clone, Serialize, Deserialize)]
pub struct SessionAttributes {
    /// Name of the registered application the session runs.
    pub application: String,
    pub slots: u32,
    /// Optional data shared by the session's tasks; serialized as raw bytes by `serde_message`.
    #[serde(with = "serde_message")]
    pub common_data: Option<CommonData>,
}
77
/// Optional schema descriptions for an application's task input, task output, and common data.
///
/// NOTE(review): the format of these strings is not defined in this module; they are passed
/// through to/from `rpc::ApplicationSchema` unchanged.
#[derive(Clone, Serialize, Deserialize)]
pub struct ApplicationSchema {
    pub input: Option<String>,
    pub output: Option<String>,
    pub common_data: Option<String>,
}
84
/// User-facing description of an application, converted to and from the RPC `ApplicationSpec`.
#[derive(Clone, Serialize, Deserialize)]
pub struct ApplicationAttributes {
    pub shim: Shim,

    pub image: Option<String>,
    pub description: Option<String>,
    pub labels: Vec<String>,
    pub command: Option<String>,
    pub arguments: Vec<String>,
    /// Environment variables, sent to the server as a list of `Environment { name, value }`.
    pub environments: HashMap<String, String>,
    pub working_directory: Option<String>,
    pub max_instances: Option<u32>,
    /// Transmitted and serialized as whole seconds; sub-second precision is dropped.
    #[serde(with = "serde_duration")]
    pub delay_release: Option<Duration>,
    pub schema: Option<ApplicationSchema>,
}
101
/// A registered application as reported by the frontend.
#[derive(Clone, Serialize, Deserialize)]
pub struct Application {
    pub name: ApplicationID,

    pub attributes: ApplicationAttributes,

    pub state: ApplicationState,
    /// Serialized by `serde_utc` as whole seconds since the Unix epoch.
    #[serde(with = "serde_utc")]
    pub creation_time: DateTime<Utc>,
}
112
/// An executor as reported by `Connection::list_executor`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Executor {
    pub id: String,
    pub state: ExecutorState,
    // NOTE(review): presumably the session currently bound to this executor, `None` when
    // unbound — confirm against server semantics.
    pub session_id: Option<String>,
    pub slots: u32,
    // NOTE(review): presumably the name of the node hosting the executor.
    pub node: String,
}
121
/// Client-side view of a Flame session.
///
/// Task operations (`create_task`, `get_task`, `list_tasks`, `watch_task`, `close`) need the
/// embedded `client`. Within this module it is set only by `Connection::create_session` and
/// `Connection::get_session`. Sessions converted from RPC messages (e.g. via
/// `Connection::list_session`) have `client: None`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Session {
    /// gRPC client used for task operations; never serialized.
    #[serde(skip)]
    pub(crate) client: Option<FlameClient>,

    pub id: SessionID,
    pub slots: u32,
    pub application: String,
    #[serde(with = "serde_utc")]
    pub creation_time: DateTime<Utc>,

    pub state: SessionState,
    // NOTE(review): presumably per-state task counters reported by the server.
    pub pending: i32,
    pub running: i32,
    pub succeed: i32,
    pub failed: i32,

    pub events: Vec<Event>,
    /// Not populated by the RPC conversion in this module (always `None` there).
    pub tasks: Option<Vec<Task>>,
}
142
/// A task belonging to session `ssn_id`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Task {
    pub id: TaskID,
    pub ssn_id: SessionID,

    pub state: TaskState,

    /// Task input payload; serialized as raw bytes by `serde_message`.
    #[serde(with = "serde_message")]
    pub input: Option<TaskInput>,
    /// Task output payload; serialized as raw bytes by `serde_message`.
    #[serde(with = "serde_message")]
    pub output: Option<TaskOutput>,

    pub events: Vec<Event>,
}
157
/// Shared, lockable handle to a [`TaskInformer`], as accepted by `Session::watch_task` and
/// `Session::run_task`.
pub type TaskInformerPtr = Arc<Mutex<dyn TaskInformer>>;

/// Callbacks invoked by `Session::watch_task` for each item of the watch stream.
pub trait TaskInformer: Send + Sync + 'static {
    /// Called with the latest state of the watched task for every successful stream item.
    fn on_update(&mut self, task: Task);
    /// Called for every error item received on the watch stream.
    fn on_error(&mut self, e: FlameError);
}
164
165impl Task {
166 pub fn is_completed(&self) -> bool {
167 self.state == TaskState::Succeed || self.state == TaskState::Failed
168 }
169
170 pub fn is_succeed(&self) -> bool {
171 self.state == TaskState::Succeed
172 }
173
174 pub fn is_failed(&self) -> bool {
175 self.state == TaskState::Failed
176 }
177}
178
179impl Connection {
180 pub async fn create_session(&self, attrs: &SessionAttributes) -> Result<Session, FlameError> {
181 trace_fn!("Connection::create_session");
182
183 let create_ssn_req = CreateSessionRequest {
184 session: Some(SessionSpec {
185 application: attrs.application.clone(),
186 slots: attrs.slots,
187 common_data: attrs.common_data.clone().map(CommonData::into),
188 }),
189 };
190
191 let mut client = FlameClient::new(self.channel.clone());
192 let ssn = client.create_session(create_ssn_req).await?;
193 let ssn = ssn.into_inner();
194
195 let mut ssn = Session::from(&ssn);
196 ssn.client = Some(client);
197
198 Ok(ssn)
199 }
200
201 pub async fn list_session(&self) -> Result<Vec<Session>, FlameError> {
202 let mut client = FlameClient::new(self.channel.clone());
203 let ssn_list = client.list_session(ListSessionRequest {}).await?;
204
205 Ok(ssn_list
206 .into_inner()
207 .sessions
208 .iter()
209 .map(Session::from)
210 .collect())
211 }
212
213 pub async fn get_session(&self, id: &SessionID) -> Result<Session, FlameError> {
214 let mut client = FlameClient::new(self.channel.clone());
215 let ssn = client
216 .get_session(GetSessionRequest {
217 session_id: id.to_string(),
218 })
219 .await?;
220
221 let ssn = ssn.into_inner();
222 let mut ssn = Session::from(&ssn);
223 ssn.client = Some(client);
224
225 Ok(ssn)
226 }
227
228 pub async fn register_application(
229 &self,
230 name: String,
231 app: ApplicationAttributes,
232 ) -> Result<(), FlameError> {
233 let mut client = FlameClient::new(self.channel.clone());
234
235 let req = RegisterApplicationRequest {
236 name,
237 application: Some(ApplicationSpec::from(app)),
238 };
239
240 let res = client
241 .register_application(Request::new(req))
242 .await?
243 .into_inner();
244
245 if res.return_code < 0 {
246 Err(FlameError::Network(res.message.unwrap_or_default()))
247 } else {
248 Ok(())
249 }
250 }
251
252 pub async fn update_application(
253 &self,
254 name: String,
255 app: ApplicationAttributes,
256 ) -> Result<(), FlameError> {
257 let mut client = FlameClient::new(self.channel.clone());
258
259 let req = UpdateApplicationRequest {
260 name,
261 application: Some(ApplicationSpec::from(app)),
262 };
263
264 let res = client
265 .update_application(Request::new(req))
266 .await?
267 .into_inner();
268
269 if res.return_code < 0 {
270 Err(FlameError::Network(res.message.unwrap_or_default()))
271 } else {
272 Ok(())
273 }
274 }
275
276 pub async fn unregister_application(&self, name: String) -> Result<(), FlameError> {
277 let mut client = FlameClient::new(self.channel.clone());
278
279 let req = UnregisterApplicationRequest { name };
280
281 let res = client
282 .unregister_application(Request::new(req))
283 .await?
284 .into_inner();
285
286 if res.return_code < 0 {
287 Err(FlameError::Network(res.message.unwrap_or_default()))
288 } else {
289 Ok(())
290 }
291 }
292
293 pub async fn list_application(&self) -> Result<Vec<Application>, FlameError> {
294 let mut client = FlameClient::new(self.channel.clone());
295 let app_list = client.list_application(ListApplicationRequest {}).await?;
296
297 Ok(app_list
298 .into_inner()
299 .applications
300 .iter()
301 .map(Application::from)
302 .collect())
303 }
304
305 pub async fn get_application(&self, name: &str) -> Result<Application, FlameError> {
306 let mut client = FlameClient::new(self.channel.clone());
307 let app = client
308 .get_application(GetApplicationRequest {
309 name: name.to_string(),
310 })
311 .await?;
312 Ok(Application::from(&app.into_inner()))
313 }
314
315 pub async fn list_executor(&self) -> Result<Vec<Executor>, FlameError> {
316 let mut client = FlameClient::new(self.channel.clone());
317 let executor_list = client.list_executor(ListExecutorRequest {}).await?;
318 Ok(executor_list
319 .into_inner()
320 .executors
321 .iter()
322 .map(Executor::from)
323 .collect())
324 }
325}
326
327impl Session {
328 pub async fn create_task(&self, input: Option<TaskInput>) -> Result<Task, FlameError> {
329 trace_fn!("Session::create_task");
330 let mut client = self
331 .client
332 .clone()
333 .ok_or(FlameError::Internal("no flame client".to_string()))?;
334
335 let create_task_req = CreateTaskRequest {
336 task: Some(TaskSpec {
337 session_id: self.id.clone(),
338 input: input.map(|input| input.to_vec()),
339 output: None,
340 }),
341 };
342
343 let task = client.create_task(create_task_req).await?;
344
345 let task = task.into_inner();
346 Ok(Task::from(&task))
347 }
348
349 pub async fn get_task(&self, id: &TaskID) -> Result<Task, FlameError> {
350 trace_fn!("Session::get_task");
351 let mut client = self
352 .client
353 .clone()
354 .ok_or(FlameError::Internal("no flame client".to_string()))?;
355
356 let get_task_req = GetTaskRequest {
357 session_id: self.id.clone(),
358 task_id: id.clone(),
359 };
360 let task = client.get_task(get_task_req).await?;
361
362 let task = task.into_inner();
363 Ok(Task::from(&task))
364 }
365
366 pub async fn list_tasks(&self) -> Result<Vec<Task>, FlameError> {
367 trace_fn!("Session::list_task");
369 let mut client = self
370 .client
371 .clone()
372 .ok_or(FlameError::Internal("no flame client".to_string()))?;
373 let task_stream = client
374 .list_task(Request::new(ListTaskRequest {
375 session_id: self.id.to_string(),
376 }))
377 .await?;
378
379 let mut task_list = vec![];
380
381 let mut task_stream = task_stream.into_inner();
382 while let Some(task) = task_stream.next().await {
383 if let Ok(t) = task {
384 task_list.push(Task::from(&t));
385 }
386 }
387
388 Ok(task_list)
389 }
390
391 pub async fn run_task(
392 &self,
393 input: Option<TaskInput>,
394 informer_ptr: TaskInformerPtr,
395 ) -> Result<(), FlameError> {
396 trace_fn!("Session::run_task");
397 self.create_task(input)
398 .and_then(|task| self.watch_task(task.ssn_id.clone(), task.id, informer_ptr))
399 .await
400 }
401
402 pub async fn watch_task(
403 &self,
404 session_id: SessionID,
405 task_id: TaskID,
406 informer_ptr: TaskInformerPtr,
407 ) -> Result<(), FlameError> {
408 trace_fn!("Session::watch_task");
409 let mut client = self
410 .client
411 .clone()
412 .ok_or(FlameError::Internal("no flame client".to_string()))?;
413
414 let watch_task_req = WatchTaskRequest {
415 session_id,
416 task_id,
417 };
418 let mut task_stream = client.watch_task(watch_task_req).await?.into_inner();
419 while let Some(task) = task_stream.next().await {
420 match task {
421 Ok(t) => {
422 let mut informer = lock_ptr!(informer_ptr)?;
423 informer.on_update(Task::from(&t));
424 }
425 Err(e) => {
426 let mut informer = lock_ptr!(informer_ptr)?;
427 informer.on_error(FlameError::from(e.clone()));
428 }
429 }
430 }
431 Ok(())
432 }
433
434 pub async fn close(&self) -> Result<(), FlameError> {
435 trace_fn!("Session::close");
436 let mut client = self
437 .client
438 .clone()
439 .ok_or(FlameError::Internal("no flame client".to_string()))?;
440
441 let close_ssn_req = CloseSessionRequest {
442 session_id: self.id.clone(),
443 };
444
445 client.close_session(close_ssn_req).await?;
446
447 Ok(())
448 }
449}
450
451impl From<&rpc::Task> for Task {
452 fn from(task: &rpc::Task) -> Self {
453 let metadata = task.metadata.clone().unwrap();
454 let spec = task.spec.clone().unwrap();
455 let status = task.status.clone().unwrap();
456 Task {
457 id: metadata.id,
458 ssn_id: spec.session_id.clone(),
459 input: spec.input.map(TaskInput::from),
460 output: spec.output.map(TaskOutput::from),
461 state: TaskState::try_from(status.state).unwrap_or(TaskState::default()),
462 events: status.events.clone().into_iter().map(Event::from).collect(),
463 }
464 }
465}
466
467impl From<&rpc::Session> for Session {
468 fn from(ssn: &rpc::Session) -> Self {
469 let metadata = ssn.metadata.clone().unwrap();
470 let status = ssn.status.clone().unwrap();
471 let spec = ssn.spec.clone().unwrap();
472
473 let naivedatetime_utc =
474 DateTime::from_timestamp_millis(status.creation_time * 1000).unwrap();
475 let creation_time = Utc.from_utc_datetime(&naivedatetime_utc.naive_utc());
476
477 Session {
478 client: None,
479 id: metadata.id,
480 slots: spec.slots,
481 application: spec.application,
482 creation_time,
483 state: SessionState::try_from(status.state).unwrap_or(SessionState::default()),
484 pending: status.pending,
485 running: status.running,
486 succeed: status.succeed,
487 failed: status.failed,
488 events: status.events.clone().into_iter().map(Event::from).collect(),
489 tasks: None,
490 }
491 }
492}
493
494impl From<&rpc::Event> for Event {
495 fn from(event: &rpc::Event) -> Self {
496 let second = event.creation_time / 1000;
497 let nanosecond = ((event.creation_time % 1000) * 1_000_000) as u32;
498
499 Self {
500 code: event.code,
501 message: event.message.clone(),
502 creation_time: DateTime::from_timestamp(second, nanosecond).unwrap(),
503 }
504 }
505}
506
507impl From<rpc::Event> for Event {
508 fn from(event: rpc::Event) -> Self {
509 Event::from(&event)
510 }
511}
512
513impl From<&rpc::Application> for Application {
514 fn from(app: &rpc::Application) -> Self {
515 let metadata = app.metadata.clone().unwrap();
516 let spec = app.spec.clone().unwrap();
517 let status = app.status.unwrap();
518
519 let naivedatetime_utc =
520 DateTime::from_timestamp_millis(status.creation_time * 1000).unwrap();
521 let creation_time = Utc.from_utc_datetime(&naivedatetime_utc.naive_utc());
522
523 Self {
524 name: metadata.name,
525 attributes: ApplicationAttributes::from(spec),
526 state: ApplicationState::from(status.state()),
527 creation_time,
528 }
529 }
530}
531
532impl From<ApplicationAttributes> for ApplicationSpec {
533 fn from(app: ApplicationAttributes) -> Self {
534 Self {
535 shim: app.shim.into(),
536 image: app.image.clone(),
537 description: app.description.clone(),
538 labels: app.labels.clone(),
539 command: app.command.clone(),
540 arguments: app.arguments.clone(),
541 environments: app
542 .environments
543 .clone()
544 .into_iter()
545 .map(|(key, value)| Environment { name: key, value })
546 .collect(),
547 working_directory: app.working_directory.clone(),
548 max_instances: app.max_instances,
549 delay_release: app.delay_release.map(|s| s.num_seconds()),
550 schema: app.schema.clone().map(rpc::ApplicationSchema::from),
551 }
552 }
553}
554
555impl From<ApplicationSpec> for ApplicationAttributes {
556 fn from(app: ApplicationSpec) -> Self {
557 Self {
558 shim: app.shim().into(),
559 image: app.image.clone(),
560 description: app.description.clone(),
561 labels: app.labels.clone(),
562 command: app.command.clone(),
563 arguments: app.arguments.clone(),
564 environments: app
565 .environments
566 .clone()
567 .into_iter()
568 .map(|env| (env.name, env.value))
569 .collect(),
570 working_directory: app.working_directory.clone(),
571 max_instances: app.max_instances,
572 delay_release: app.delay_release.map(Duration::seconds),
573 schema: app.schema.clone().map(ApplicationSchema::from),
574 }
575 }
576}
577
578impl From<ApplicationSchema> for rpc::ApplicationSchema {
579 fn from(schema: ApplicationSchema) -> Self {
580 Self {
581 input: schema.input,
582 output: schema.output,
583 common_data: schema.common_data,
584 }
585 }
586}
587
588impl From<rpc::ApplicationSchema> for ApplicationSchema {
589 fn from(schema: rpc::ApplicationSchema) -> Self {
590 Self {
591 input: schema.input,
592 output: schema.output,
593 common_data: schema.common_data,
594 }
595 }
596}
597
598impl From<&rpc::Executor> for Executor {
599 fn from(e: &rpc::Executor) -> Self {
600 let spec = e.spec.clone().unwrap();
601 let status = e.status.clone().unwrap();
602 let metadata = e.metadata.clone().unwrap();
603
604 let state = rpc::ExecutorState::try_from(status.state).unwrap().into();
605
606 Executor {
607 id: metadata.id,
608 session_id: status.session_id,
609 slots: spec.slots,
610 node: spec.node,
611 state,
612 }
613 }
614}
615
616impl From<rpc::Executor> for Executor {
617 fn from(e: rpc::Executor) -> Self {
618 Executor::from(&e)
619 }
620}
621
622mod serde_duration {
623 use chrono::Duration;
624 use serde::{Deserialize, Deserializer, Serializer};
625
626 pub fn serialize<S>(duration: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
627 where
628 S: Serializer,
629 {
630 match duration {
631 Some(duration) => serializer.serialize_i64(duration.num_seconds()),
632 None => serializer.serialize_none(),
633 }
634 }
635
636 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
637 where
638 D: Deserializer<'de>,
639 {
640 let seconds = i64::deserialize(deserializer)?;
641 Ok(Some(Duration::seconds(seconds)))
642 }
643}
644
645mod serde_utc {
646 use chrono::{DateTime, Utc};
647 use serde::{self, Deserialize, Deserializer, Serializer};
648
649 pub fn serialize<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
650 where
651 S: Serializer,
652 {
653 serializer.serialize_i64(date.timestamp())
654 }
655
656 pub fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
657 where
658 D: Deserializer<'de>,
659 {
660 let timestamp = i64::deserialize(deserializer)?;
661 DateTime::<Utc>::from_timestamp(timestamp, 0)
662 .ok_or(serde::de::Error::custom("invalid timestamp"))
663 }
664}
665
666mod serde_message {
667 use bytes::Bytes;
668 use serde::{Deserialize, Deserializer, Serializer};
669
670 pub fn serialize<S>(message: &Option<Bytes>, serializer: S) -> Result<S::Ok, S::Error>
671 where
672 S: Serializer,
673 {
674 match message {
675 Some(message) => serializer.serialize_bytes(message),
676 None => serializer.serialize_none(),
677 }
678 }
679
680 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Bytes>, D::Error>
681 where
682 D: Deserializer<'de>,
683 {
684 let bytes = Vec::<u8>::deserialize(deserializer)?;
685 Ok(Some(Bytes::from(bytes)))
686 }
687}