1use std::{
5 collections::{BTreeMap, BTreeSet},
6 sync::Arc,
7};
8
9use futures::{lock::Mutex, stream::StreamExt, FutureExt};
10use linera_base::{
11 data_types::TimeDelta,
12 identifiers::{ApplicationId, ChainId},
13};
14use linera_client::chain_listener::{ClientContext, ListenerCommand};
15use linera_core::{client::ChainClient, node::NotificationStream, worker::Reason};
16use linera_sdk::abis::controller::{LocalWorkerState, Operation, WorkerCommand};
17use serde_json::json;
18use tokio::{
19 select,
20 sync::mpsc::{self, UnboundedSender},
21};
22use tokio_util::sync::CancellationToken;
23use tracing::{debug, error, info};
24
25use crate::task_processor::{OperatorMap, TaskProcessor};
26
27#[derive(Debug)]
29pub struct Update {
30 pub application_ids: Vec<ApplicationId>,
31}
32
33struct ProcessorHandle {
34 update_sender: mpsc::UnboundedSender<Update>,
35}
36
37pub struct Controller<Ctx: ClientContext> {
38 chain_id: ChainId,
39 controller_id: ApplicationId,
40 context: Arc<Mutex<Ctx>>,
41 chain_client: ChainClient<Ctx::Environment>,
42 cancellation_token: CancellationToken,
43 notifications: NotificationStream,
44 operators: OperatorMap,
45 retry_delay: TimeDelta,
46 processors: BTreeMap<ChainId, ProcessorHandle>,
47 listened_local_chains: BTreeSet<ChainId>,
48 command_sender: UnboundedSender<ListenerCommand>,
49}
50
51impl<Ctx> Controller<Ctx>
52where
53 Ctx: ClientContext + Send + Sync + 'static,
54 Ctx::Environment: 'static,
55 <Ctx::Environment as linera_core::Environment>::Storage: Clone,
56{
57 #[allow(clippy::too_many_arguments)]
58 pub fn new(
59 chain_id: ChainId,
60 controller_id: ApplicationId,
61 context: Arc<Mutex<Ctx>>,
62 chain_client: ChainClient<Ctx::Environment>,
63 cancellation_token: CancellationToken,
64 operators: OperatorMap,
65 retry_delay: TimeDelta,
66 command_sender: UnboundedSender<ListenerCommand>,
67 ) -> Self {
68 let notifications = chain_client.subscribe().expect("client subscription");
69 Self {
70 chain_id,
71 controller_id,
72 context,
73 chain_client,
74 cancellation_token,
75 notifications,
76 operators,
77 retry_delay,
78 processors: BTreeMap::new(),
79 listened_local_chains: BTreeSet::new(),
80 command_sender,
81 }
82 }
83
84 pub async fn run(mut self) {
85 info!(
86 "Watching for notifications for controller chain {}",
87 self.chain_id
88 );
89 self.process_controller_state().await;
90 loop {
91 select! {
92 Some(notification) = self.notifications.next() => {
93 if let Reason::NewBlock { .. } = notification.reason {
94 debug!("Processing notification on controller chain {}", self.chain_id);
95 self.process_controller_state().await;
96 }
97 }
98 _ = self.cancellation_token.cancelled().fuse() => {
99 break;
100 }
101 }
102 }
103 debug!("Notification stream ended.");
104 }
105
106 async fn process_controller_state(&mut self) {
107 let state = match self.query_controller_state().await {
108 Ok(state) => state,
109 Err(error) => {
110 error!("Error reading controller state: {error}");
111 return;
112 }
113 };
114 let Some(worker) = state.local_worker else {
115 self.register_worker().await;
117 return;
118 };
119 assert_eq!(
120 worker.owner,
121 self.chain_client
122 .preferred_owner()
123 .expect("The current wallet should own the chain being watched"),
124 "We should be registered with the current account owner."
125 );
126
127 let mut chain_apps: BTreeMap<ChainId, Vec<ApplicationId>> = BTreeMap::new();
129 for service in &state.local_services {
130 chain_apps
131 .entry(service.chain_id)
132 .or_default()
133 .push(service.application_id);
134 }
135
136 let old_chains: BTreeSet<_> = self.processors.keys().cloned().collect();
137
138 for (service_chain_id, application_ids) in chain_apps {
140 if let Err(err) = self
141 .update_or_spawn_processor(service_chain_id, application_ids)
142 .await
143 {
144 error!("Error updating or spawning processor: {err}");
145 return;
146 }
147 }
148
149 let active_chains: std::collections::BTreeSet<_> =
152 state.local_services.iter().map(|s| s.chain_id).collect();
153 let stale_chains: BTreeSet<_> = self
154 .processors
155 .keys()
156 .filter(|chain_id| !active_chains.contains(chain_id))
157 .cloned()
158 .collect();
159 for chain_id in &stale_chains {
160 if let Some(handle) = self.processors.get(chain_id) {
161 let update = Update {
162 application_ids: Vec::new(),
163 };
164 if handle.update_sender.send(update).is_err() {
165 self.processors.remove(chain_id);
167 }
168 }
169 }
170
171 let local_chains: BTreeSet<_> = state.local_chains.iter().cloned().collect();
173
174 let old_listened: BTreeSet<_> = old_chains
176 .union(&self.listened_local_chains)
177 .cloned()
178 .collect();
179
180 let desired_listened: BTreeSet<_> = active_chains.union(&local_chains).cloned().collect();
182
183 let owner = worker.owner;
185 let new_chains: BTreeMap<_, _> = desired_listened
186 .difference(&old_listened)
187 .map(|chain_id| (*chain_id, Some(owner)))
188 .collect();
189
190 let chains_to_stop: BTreeSet<_> = old_listened
192 .difference(&desired_listened)
193 .cloned()
194 .collect();
195
196 self.listened_local_chains = local_chains.difference(&active_chains).cloned().collect();
199
200 if let Err(error) = self.command_sender.send(ListenerCommand::SetMessagePolicy(
201 state.local_message_policy,
202 )) {
203 error!(%error, "error sending a command to chain listener");
204 }
205 if let Err(error) = self
206 .command_sender
207 .send(ListenerCommand::Listen(new_chains))
208 {
209 error!(%error, "error sending a command to chain listener");
210 }
211 if let Err(error) = self
212 .command_sender
213 .send(ListenerCommand::StopListening(chains_to_stop))
214 {
215 error!(%error, "error sending a command to chain listener");
216 }
217 }
218
219 async fn register_worker(&mut self) {
220 let capabilities = self.operators.keys().cloned().collect();
221 let command = WorkerCommand::RegisterWorker { capabilities };
222 let owner = self
223 .chain_client
224 .preferred_owner()
225 .expect("The current wallet should own the chain being watched");
226 let bytes =
227 bcs::to_bytes(&Operation::ExecuteWorkerCommand { owner, command }).expect("bcs bytes");
228 let operation = linera_execution::Operation::User {
229 application_id: self.controller_id,
230 bytes,
231 };
232 if let Err(e) = self
233 .chain_client
234 .execute_operations(vec![operation], vec![])
235 .await
236 {
237 error!("Failed to execute worker on-chain registration: {e}");
239 }
240 }
241
242 async fn update_or_spawn_processor(
243 &mut self,
244 service_chain_id: ChainId,
245 application_ids: Vec<ApplicationId>,
246 ) -> Result<(), anyhow::Error> {
247 if let Some(handle) = self.processors.get(&service_chain_id) {
248 let update = Update {
250 application_ids: application_ids.clone(),
251 };
252 if handle.update_sender.send(update).is_err() {
253 self.processors.remove(&service_chain_id);
255 self.spawn_processor(service_chain_id, application_ids)
256 .await?;
257 }
258 } else {
259 self.spawn_processor(service_chain_id, application_ids)
261 .await?;
262 }
263 Ok(())
264 }
265
266 async fn spawn_processor(
267 &mut self,
268 service_chain_id: ChainId,
269 application_ids: Vec<ApplicationId>,
270 ) -> Result<(), anyhow::Error> {
271 info!(
272 "Spawning TaskProcessor for chain {} with applications {:?}",
273 service_chain_id, application_ids
274 );
275
276 let (update_sender, update_receiver) = mpsc::unbounded_channel();
277
278 let mut chain_client = self
279 .context
280 .lock()
281 .await
282 .make_chain_client(service_chain_id)
283 .await?;
284 if let Some(owner) = self.chain_client.preferred_owner() {
287 chain_client.set_preferred_owner(owner);
288 }
289 let processor = TaskProcessor::new(
290 service_chain_id,
291 application_ids,
292 chain_client,
293 self.cancellation_token.child_token(),
294 self.operators.clone(),
295 self.retry_delay,
296 Some(update_receiver),
297 );
298
299 tokio::spawn(processor.run());
300
301 self.processors
302 .insert(service_chain_id, ProcessorHandle { update_sender });
303
304 Ok(())
305 }
306
307 async fn query_controller_state(&mut self) -> Result<LocalWorkerState, anyhow::Error> {
308 let query = "query { localWorkerState }";
309 let bytes = serde_json::to_vec(&json!({"query": query}))?;
310 let query = linera_execution::Query::User {
311 application_id: self.controller_id,
312 bytes,
313 };
314 let (
315 linera_execution::QueryOutcome {
316 response,
317 operations: _,
318 },
319 _,
320 ) = self.chain_client.query_application(query, None).await?;
321 let linera_execution::QueryResponse::User(response) = response else {
322 anyhow::bail!("cannot get a system response for a user query");
323 };
324 let mut response: serde_json::Value = serde_json::from_slice(&response)?;
325 let state = serde_json::from_value(response["data"]["localWorkerState"].take())?;
326 Ok(state)
327 }
328}