inference_runtime/
dp_coordinator.rs1use std::collections::HashMap;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use parking_lot::RwLock;
15use rakka_core::actor::{Actor, Context, UntypedActorRef};
16use tokio::sync::oneshot;
17
18use inference_core::error::InferenceError;
19
20#[derive(Clone)]
21pub struct RouteTarget {
22 pub engine: UntypedActorRef,
25 pub load: f64,
28}
29
30impl std::fmt::Debug for RouteTarget {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("RouteTarget")
33 .field("load", &self.load)
34 .finish_non_exhaustive()
35 }
36}
37
38#[derive(Clone, Default)]
39struct CoordinatorState {
40 routes: HashMap<String, Vec<RouteTarget>>,
41}
42
43pub enum DpCoordinatorMsg {
44 Register {
45 deployment: String,
46 target: RouteTarget,
47 },
48 Deregister {
49 deployment: String,
50 engine_path: rakka_core::actor::ActorPath,
51 },
52 ReportLoad {
53 deployment: String,
54 engine_path: rakka_core::actor::ActorPath,
55 load: f64,
56 },
57 RouteTo {
58 deployment: String,
59 reply: oneshot::Sender<Result<RouteTarget, InferenceError>>,
60 },
61}
62
63pub struct DpCoordinatorActor {
64 state: Arc<RwLock<CoordinatorState>>,
65}
66
67impl Default for DpCoordinatorActor {
68 fn default() -> Self {
69 Self {
70 state: Arc::new(RwLock::new(CoordinatorState::default())),
71 }
72 }
73}
74
75impl DpCoordinatorActor {
76 pub fn new() -> Self {
77 Self::default()
78 }
79
80 fn register(&self, deployment: String, target: RouteTarget) {
81 self.state
82 .write()
83 .routes
84 .entry(deployment)
85 .or_default()
86 .push(target);
87 }
88
89 fn deregister(&self, deployment: &str, path: &rakka_core::actor::ActorPath) {
90 if let Some(v) = self.state.write().routes.get_mut(deployment) {
91 v.retain(|t| t.engine.path() != path);
92 }
93 }
94
95 fn report_load(&self, deployment: &str, path: &rakka_core::actor::ActorPath, load: f64) {
96 if let Some(v) = self.state.write().routes.get_mut(deployment) {
97 for t in v.iter_mut() {
98 if t.engine.path() == path {
99 t.load = load;
100 }
101 }
102 }
103 }
104
105 fn pick(&self, deployment: &str) -> Result<RouteTarget, InferenceError> {
106 let st = self.state.read();
107 let candidates = st
108 .routes
109 .get(deployment)
110 .filter(|v| !v.is_empty())
111 .ok_or_else(|| InferenceError::Internal(format!("no engine for deployment `{deployment}`")))?;
112 let pick = candidates
114 .iter()
115 .min_by(|a, b| a.load.partial_cmp(&b.load).unwrap_or(std::cmp::Ordering::Equal))
116 .cloned()
117 .ok_or_else(|| InferenceError::Internal("empty candidate set".into()))?;
118 Ok(pick)
119 }
120}
121
122#[async_trait]
123impl Actor for DpCoordinatorActor {
124 type Msg = DpCoordinatorMsg;
125
126 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: Self::Msg) {
127 match msg {
128 DpCoordinatorMsg::Register { deployment, target } => self.register(deployment, target),
129 DpCoordinatorMsg::Deregister {
130 deployment,
131 engine_path,
132 } => self.deregister(&deployment, &engine_path),
133 DpCoordinatorMsg::ReportLoad {
134 deployment,
135 engine_path,
136 load,
137 } => self.report_load(&deployment, &engine_path, load),
138 DpCoordinatorMsg::RouteTo { deployment, reply } => {
139 let _ = reply.send(self.pick(&deployment));
140 }
141 }
142 }
143}