use std::time::Duration;

use bon::Builder;
use communication::{APICommand, WorkerGrpcBackend};
use config::CONFIG;
use malstrom::{
    coordinator::{Coordinator, CoordinatorExecutionError},
    runtime::{CommunicationError, RuntimeFlavor},
    snapshot::PersistenceBackend,
    types::WorkerId,
    worker::{StreamProvider, WorkerBuilder, WorkerExecutionError},
};
use thiserror::Error;

mod communication;
mod config;

use crate::communication::CoordinatorGrpcBackend;

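/// Malstrom runtime for executing jobs on Kubernetes. The same binary serves
/// as both coordinator and worker; the role of each pod is determined from its
/// environment at startup (see [`KubernetesRuntime::execute_auto`]).
///
/// A minimal usage sketch; `NoPersistence` stands in for whatever
/// [`PersistenceBackend`] the deployment actually uses:
///
/// ```ignore
/// fn build_dataflow(provider: &mut dyn StreamProvider) {
///     // assemble the job's dataflow on `provider`
/// }
///
/// KubernetesRuntime::builder()
///     .persistence(NoPersistence)
///     .snapshots(Duration::from_secs(30))
///     .build(build_dataflow)
///     .execute_auto()
///     .unwrap();
/// ```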
#[derive(Builder)]
pub struct KubernetesRuntime<P> {
    /// Function which assembles the job's dataflow on the given stream provider.
    #[builder(finish_fn)]
    build: fn(&mut dyn StreamProvider),
    /// Backend used to persist state snapshots.
    persistence: P,
    /// Interval at which snapshots are taken; `None` disables snapshotting.
    snapshots: Option<Duration>,
}

impl<P> KubernetesRuntime<P>
where
    P: PersistenceBackend + Clone + Send + Sync,
{
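    /// Execute this process in whichever role its pod was assigned: as the
    /// coordinator if `CONFIG.is_coordinator` is set, as a worker otherwise.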
    pub fn execute_auto(self) -> Result<(), ExecuteAutoError> {
        if CONFIG.is_coordinator {
            self.execute_coordinator()?;
        } else {
            self.execute_worker()?;
        }
        Ok(())
    }

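    /// Execute this process as a worker, running the user-built dataflow on a
    /// multi-threaded Tokio runtime.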
    pub fn execute_worker(self) -> Result<(), WorkerExecutionError> {
        // Size the async runtime at two threads per initially scheduled worker.
        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads((CONFIG.initial_scale * 2) as usize)
            .enable_all()
            .build()
            .unwrap();
        let mut worker = WorkerBuilder::new(
            KubernetesRuntimeFlavor(rt.handle().clone()),
            self.persistence,
        );
        // `WorkerBuilder` implements `StreamProvider`, so the user-supplied
        // function can assemble the dataflow directly on the worker.
        (self.build)(&mut worker);
        worker.execute()
    }

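    /// Execute this process as the coordinator: runs the coordinator loop on a
    /// blocking thread and forwards API commands (currently only rescaling)
    /// received via gRPC to it.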
    pub fn execute_coordinator(self) -> Result<(), CoordinatorExecutionError> {
        let rt = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(4)
            .enable_all()
            .build()
            .unwrap();
        // Commands received by the gRPC API server are forwarded through this channel.
        let (api_tx, api_rx) = flume::unbounded();

        let communication = CoordinatorGrpcBackend::new(api_tx, rt.handle().clone()).unwrap();
        let (coordinator, coordinator_api) = Coordinator::new();

        // The coordinator's execution loop is blocking, so it must not run on
        // the async runtime's worker threads.
        let coordinator_thread = rt.spawn_blocking(move || {
            coordinator.execute(
                CONFIG.initial_scale,
                self.snapshots,
                self.persistence,
                communication,
            )
        });

        let api_thread = rt.spawn(async move {
            // `recv_async` rather than the blocking `recv`: a blocking call
            // here would stall one of the runtime's worker threads.
            while let Ok(req) = api_rx.recv_async().await {
                match req {
                    APICommand::Rescale(rescale_command) => {
                        coordinator_api
                            .rescale(rescale_command.desired)
                            .await
                            .unwrap();
                        // Notify the requester; it may already be gone.
                        let _ = rescale_command.on_finish.send(());
                    }
                }
            }
        });

        // The API loop terminates once the coordinator drops its end of the
        // channel; propagate the coordinator's result afterwards.
        rt.block_on(api_thread).unwrap();
        rt.block_on(coordinator_thread)
            .expect("coordinator task panicked")
    }
}

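/// Error returned by [`KubernetesRuntime::execute_auto`], wrapping the error of
/// whichever role this process executed in.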
#[derive(Debug, Error)]
pub enum ExecuteAutoError {
    #[error(transparent)]
    WorkerBuild(#[from] WorkerExecutionError),
    #[error(transparent)]
    CoordinatorCreate(#[from] CoordinatorExecutionError),
}

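/// A [`RuntimeFlavor`] whose workers communicate via gRPC.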
pub struct KubernetesRuntimeFlavor(tokio::runtime::Handle);

impl RuntimeFlavor for KubernetesRuntimeFlavor {
    type Communication = WorkerGrpcBackend;

    fn communication(&mut self) -> Result<Self::Communication, CommunicationError> {
        WorkerGrpcBackend::new(self.0.clone()).map_err(CommunicationError::from_error)
    }

    fn this_worker_id(&self) -> WorkerId {
        crate::config::CONFIG.get_worker_id()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::communication::transport::GrpcTransport;
    use malstrom::runtime::communication::BiStreamTransport;
    use tokio::net::TcpListener;
    use tokio_stream::wrappers::TcpListenerStream;
    use tonic::transport::Endpoint;

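    /// Bind a TCP listener on an ephemeral local port, returning it as a
    /// connection stream plus an [`Endpoint`] pointing at it.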
    async fn get_listener_stream() -> (TcpListenerStream, Endpoint) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let stream = TcpListenerStream::new(listener);
        let endpoint = Endpoint::try_from(format!("http://{addr}")).unwrap();
        (stream, endpoint)
    }

    #[tokio::test]
    async fn test_message_coordinator_to_worker() {
        let rt = tokio::runtime::Handle::current();

        let (coord_incoming, coord_endpoint) = get_listener_stream().await;
        let (worker_incoming, worker_endpoint) = get_listener_stream().await;

        let (command_tx, _command_rx) = flume::unbounded();
        let coord_backend =
            CoordinatorGrpcBackend::new_with_incoming(rt.clone(), command_tx, coord_incoming)
                .unwrap();
        let worker_backend =
            WorkerGrpcBackend::new_with_incoming(rt.clone(), worker_incoming).unwrap();

        let coord_transport = GrpcTransport::coordinator_worker(&coord_backend, 0, worker_endpoint);
        let worker_transport = GrpcTransport::worker_coordinator(&worker_backend, coord_endpoint);

        let msg = vec![0, 0, 8, 1, 2];
        coord_transport.send(msg.clone()).unwrap();
        let received = worker_transport.recv_async().await.unwrap();

        assert_eq!(msg, received);
    }

    #[tokio::test]
    async fn test_message_worker_to_coordinator() {
        let rt = tokio::runtime::Handle::current();

        let (coord_incoming, coord_endpoint) = get_listener_stream().await;
        let (worker_incoming, worker_endpoint) = get_listener_stream().await;

        let (command_tx, _command_rx) = flume::unbounded();
        let coord_backend =
            CoordinatorGrpcBackend::new_with_incoming(rt.clone(), command_tx, coord_incoming)
                .unwrap();
        let worker_backend =
            WorkerGrpcBackend::new_with_incoming(rt.clone(), worker_incoming).unwrap();

        let coord_transport = GrpcTransport::coordinator_worker(&coord_backend, 0, worker_endpoint);
        let worker_transport = GrpcTransport::worker_coordinator(&worker_backend, coord_endpoint);

        let msg = vec![5, 5, 8, 1, 2];
        worker_transport.send(msg.clone()).unwrap();
        let received = coord_transport.recv_async().await.unwrap();

        assert_eq!(msg, received);
    }

    #[tokio::test]
    async fn test_message_coordinator_to_unconnected() {
        let rt = tokio::runtime::Handle::current();

        let (coord_incoming, _) = get_listener_stream().await;

        let (command_tx, _command_rx) = flume::unbounded();
        let coord_backend =
            CoordinatorGrpcBackend::new_with_incoming(rt.clone(), command_tx, coord_incoming)
                .unwrap();

        // Port 1 is valid (unlike 99999, which exceeds the maximum TCP port of
        // 65535 and fails URI parsing) but almost certainly has no listener;
        // the send is still expected to succeed without a reachable server.
        let fake_endpoint = Endpoint::try_from("http://127.0.0.1:1").unwrap();
        let coord_transport = GrpcTransport::coordinator_worker(&coord_backend, 0, fake_endpoint);
        let msg = vec![5, 5, 8, 1, 2];
        coord_transport.send(msg).unwrap();
    }
}