// stryke/controller.rs

//! `stryke controller` — interactive REPL for coordinating stress-test agents.
//!
//! ## Usage
//!
//! ```sh
//! stryke controller                    # listen on 0.0.0.0:9999
//! stryke controller --port 8888        # custom port
//! stryke controller --bind 10.0.0.1    # specific interface
//! ```
//!
//! ## Commands
//!
//! - `status` — list connected agents
//! - `fire [duration]` — start a stress test on all agents (duration in seconds, default 10)
//! - `fire node1,node2 [duration]` — start a stress test on the named agents only
//! - `terminate` — stop the stress test
//! - `shutdown` — disconnect all agents and exit
//! - `help` — show commands

use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener, TcpStream};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};

use crate::agent::{
    frame_kind, AgentHello, AgentHelloAck, AgentState, FireCommand, WorkloadType,
    AGENT_PROTO_VERSION,
};

33/// Connected agent state
34struct ConnectedAgent {
35    stream: TcpStream,
36    hostname: String,
37    cores: usize,
38    memory_bytes: u64,
39    agent_name: Option<String>,
40    state: AgentState,
41    #[allow(dead_code)]
42    session_id: u64,
43    connected_at: Instant,
44}
45
46/// Controller state
47pub struct Controller {
48    agents: Arc<Mutex<HashMap<u64, ConnectedAgent>>>,
49    next_session_id: AtomicU64,
50    running: AtomicBool,
51}
52
53impl Default for Controller {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl Controller {
60    pub fn new() -> Self {
61        Self {
62            agents: Arc::new(Mutex::new(HashMap::new())),
63            next_session_id: AtomicU64::new(1),
64            running: AtomicBool::new(true),
65        }
66    }
67
68    /// Accept incoming agent connections
69    fn accept_loop(&self, listener: TcpListener) {
70        for stream in listener.incoming() {
71            if !self.running.load(Ordering::Relaxed) {
72                break;
73            }
74
75            match stream {
76                Ok(mut stream) => {
77                    let session_id = self.next_session_id.fetch_add(1, Ordering::Relaxed);
78
79                    // Read AGENT_HELLO
80                    let (kind, payload) = match read_frame(&mut stream) {
81                        Ok(f) => f,
82                        Err(e) => {
83                            eprintln!("controller: failed to read hello: {}", e);
84                            continue;
85                        }
86                    };
87
88                    if kind != frame_kind::AGENT_HELLO {
89                        eprintln!("controller: expected AGENT_HELLO, got {}", kind);
90                        continue;
91                    }
92
93                    let hello: AgentHello = match bincode::deserialize(&payload) {
94                        Ok(h) => h,
95                        Err(e) => {
96                            eprintln!("controller: invalid hello: {}", e);
97                            continue;
98                        }
99                    };
100
101                    if hello.proto_version != AGENT_PROTO_VERSION {
102                        let ack = AgentHelloAck {
103                            session_id: 0,
104                            accepted: false,
105                            message: format!(
106                                "protocol version mismatch: got {}, expected {}",
107                                hello.proto_version, AGENT_PROTO_VERSION
108                            ),
109                        };
110                        let ack_bytes = bincode::serialize(&ack).unwrap();
111                        let _ = write_frame(&mut stream, frame_kind::AGENT_HELLO_ACK, &ack_bytes);
112                        continue;
113                    }
114
115                    // Send HELLO_ACK
116                    let ack = AgentHelloAck {
117                        session_id,
118                        accepted: true,
119                        message: "connected".to_string(),
120                    };
121                    let ack_bytes = bincode::serialize(&ack).unwrap();
122                    if let Err(e) =
123                        write_frame(&mut stream, frame_kind::AGENT_HELLO_ACK, &ack_bytes)
124                    {
125                        eprintln!("controller: failed to send hello ack: {}", e);
126                        continue;
127                    }
128
129                    let name = hello
130                        .agent_name
131                        .clone()
132                        .unwrap_or_else(|| hello.hostname.clone());
133                    eprintln!(
134                        "[agent connected] {} (cores={}, session={})",
135                        name, hello.cores, session_id
136                    );
137
138                    let agent = ConnectedAgent {
139                        stream,
140                        hostname: hello.hostname,
141                        cores: hello.cores,
142                        memory_bytes: hello.memory_bytes,
143                        agent_name: hello.agent_name,
144                        state: AgentState::Idle,
145                        session_id,
146                        connected_at: Instant::now(),
147                    };
148
149                    self.agents.lock().unwrap().insert(session_id, agent);
150                }
151                Err(e) => {
152                    if self.running.load(Ordering::Relaxed) {
153                        eprintln!("controller: accept error: {}", e);
154                    }
155                }
156            }
157        }
158    }
159
160    /// Send FIRE to all agents
161    fn fire_all(&self, duration_secs: f64) {
162        let cmd = FireCommand {
163            workload: WorkloadType::Cpu,
164            duration_secs,
165            intensity: 1.0,
166        };
167        let cmd_bytes = bincode::serialize(&cmd).unwrap();
168
169        let mut agents = self.agents.lock().unwrap();
170        let mut fired = 0;
171
172        for agent in agents.values_mut() {
173            if write_frame(&mut agent.stream, frame_kind::FIRE, &cmd_bytes).is_ok() {
174                agent.state = AgentState::Firing;
175                fired += 1;
176            }
177        }
178
179        eprintln!("[fire] {} agents, duration={}s", fired, duration_secs);
180    }
181
182    /// Send TERMINATE to all agents
183    fn terminate_all(&self) {
184        let mut agents = self.agents.lock().unwrap();
185        let mut terminated = 0;
186
187        for agent in agents.values_mut() {
188            if write_frame(&mut agent.stream, frame_kind::TERMINATE, &[]).is_ok() {
189                agent.state = AgentState::Terminated;
190                terminated += 1;
191            }
192        }
193
194        eprintln!("[terminate] {} agents", terminated);
195    }
196
197    /// Print status of all agents
198    fn print_status(&self) {
199        let agents = self.agents.lock().unwrap();
200
201        if agents.is_empty() {
202            println!("No agents connected.");
203            return;
204        }
205
206        println!(
207            "{:<20} {:>6} {:>10} {:>12} {:>10}",
208            "AGENT", "CORES", "MEMORY", "STATE", "UPTIME"
209        );
210        println!("{}", "-".repeat(62));
211
212        for agent in agents.values() {
213            let name = agent
214                .agent_name
215                .clone()
216                .unwrap_or_else(|| agent.hostname.clone());
217            let mem_gb = agent.memory_bytes / (1024 * 1024 * 1024);
218            let state = match agent.state {
219                AgentState::Idle => "idle",
220                AgentState::Armed => "armed",
221                AgentState::Firing => "FIRING",
222                AgentState::Terminated => "terminated",
223            };
224            let uptime = agent.connected_at.elapsed().as_secs();
225
226            println!(
227                "{:<20} {:>6} {:>8}GB {:>12} {:>8}s",
228                name, agent.cores, mem_gb, state, uptime
229            );
230        }
231
232        let total_cores: usize = agents.values().map(|a| a.cores).sum();
233        let firing_count = agents
234            .values()
235            .filter(|a| a.state == AgentState::Firing)
236            .count();
237
238        println!();
239        println!(
240            "Total: {} agents, {} cores, {} firing",
241            agents.len(),
242            total_cores,
243            firing_count
244        );
245    }
246
247    /// Send SHUTDOWN to all agents
248    fn shutdown_all(&self) {
249        let mut agents = self.agents.lock().unwrap();
250
251        for agent in agents.values_mut() {
252            let _ = write_frame(&mut agent.stream, frame_kind::SHUTDOWN, &[]);
253        }
254
255        agents.clear();
256        self.running.store(false, Ordering::Relaxed);
257        eprintln!("[shutdown] all agents disconnected");
258    }
259
260    /// Run the REPL
261    fn run_repl(&self) {
262        use std::io::{stdin, BufRead};
263
264        println!("stryke controller v{}", env!("CARGO_PKG_VERSION"));
265        println!("Type 'help' for commands, Ctrl-C to exit\n");
266
267        let stdin = stdin();
268        for line in stdin.lock().lines() {
269            let line = match line {
270                Ok(l) => l,
271                Err(_) => break,
272            };
273
274            let parts: Vec<&str> = line.split_whitespace().collect();
275            if parts.is_empty() {
276                continue;
277            }
278
279            match parts[0] {
280                "status" | "s" => self.print_status(),
281                "fire" | "f" => {
282                    let duration = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(10.0);
283                    self.fire_all(duration);
284                }
285                "terminate" | "t" | "stop" => self.terminate_all(),
286                "shutdown" | "quit" | "exit" | "q" => {
287                    self.shutdown_all();
288                    break;
289                }
290                "help" | "h" | "?" => {
291                    println!("Commands:");
292                    println!("  status (s)           List connected agents");
293                    println!("  fire [SECS] (f)      Start stress test (default: 10s)");
294                    println!("  terminate (t)        Stop stress test");
295                    println!("  shutdown (q)         Disconnect all and exit");
296                    println!("  help (h)             Show this help");
297                }
298                _ => println!("Unknown command: {}. Type 'help' for commands.", parts[0]),
299            }
300        }
301    }
302}
303
/// Read one length-prefixed frame from `r`.
///
/// Wire format: a little-endian `u64` length covering the kind byte plus the
/// payload, then the kind byte, then the payload. Returns `(kind, payload)`.
///
/// # Errors
///
/// Propagates I/O errors from `r`, including `UnexpectedEof` on a short read.
/// Returns `InvalidData` when the declared length is zero or exceeds the
/// frame-size cap.
fn read_frame<R: Read>(r: &mut R) -> std::io::Result<(u8, Vec<u8>)> {
    // Frames arrive from any peer that can connect, so the declared length
    // cannot be trusted as an allocation size. 16 MiB is far above what the
    // handshake needs; raise it if larger frames are ever expected.
    const MAX_FRAME_LEN: u64 = 16 * 1024 * 1024;

    let mut len_buf = [0u8; 8];
    r.read_exact(&mut len_buf)?;
    let len = u64::from_le_bytes(len_buf);

    if len == 0 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "empty frame",
        ));
    }
    if len > MAX_FRAME_LEN {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("frame too large: {} bytes (max {})", len, MAX_FRAME_LEN),
        ));
    }

    let mut kind = [0u8; 1];
    r.read_exact(&mut kind)?;

    // Lossless: len - 1 < MAX_FRAME_LEN, which fits in usize on 32-bit targets too.
    let mut payload = vec![0u8; (len - 1) as usize];
    r.read_exact(&mut payload)?;

    Ok((kind[0], payload))
}

/// Write one frame to `w` and flush it.
///
/// The frame is a little-endian `u64` length of (kind + payload), the kind
/// byte, then the payload. The frame is assembled in memory and written with a
/// single `write_all`. On an unbuffered `TcpStream` this avoids sending the
/// 8-byte header as its own tiny segment, which interacts badly with Nagle's
/// algorithm and delayed ACKs.
fn write_frame<W: Write>(w: &mut W, kind: u8, payload: &[u8]) -> std::io::Result<()> {
    let total_len = 1 + payload.len();

    let mut frame = Vec::with_capacity(8 + total_len);
    frame.extend_from_slice(&(total_len as u64).to_le_bytes());
    frame.push(kind);
    frame.extend_from_slice(payload);

    w.write_all(&frame)?;
    w.flush()
}

330/// Main entry point
331pub fn run_controller(bind: &str, port: u16) -> i32 {
332    let addr = format!("{}:{}", bind, port);
333
334    let listener = match TcpListener::bind(&addr) {
335        Ok(l) => l,
336        Err(e) => {
337            eprintln!("controller: cannot bind to {}: {}", addr, e);
338            return 1;
339        }
340    };
341
342    eprintln!("stryke controller listening on {}", addr);
343    eprintln!("Waiting for agents...\n");
344
345    let controller = Arc::new(Controller::new());
346
347    // Spawn accept thread
348    let ctrl_clone = Arc::clone(&controller);
349    let accept_handle = thread::spawn(move || {
350        ctrl_clone.accept_loop(listener);
351    });
352
353    // Run REPL on main thread
354    controller.run_repl();
355
356    // Cleanup
357    controller.running.store(false, Ordering::Relaxed);
358    let _ = accept_handle.join();
359
360    0
361}
362
/// Print usage for `stryke controller` to stdout.
///
/// The output covers the CLI options, the commands accepted by the REPL, and a
/// short example session.
pub fn print_help() {
    const HELP_LINES: &[&str] = &[
        "stryke controller — Distributed load testing controller",
        "",
        "USAGE:",
        "    stryke controller [OPTIONS]",
        "",
        "OPTIONS:",
        "    --bind ADDR          Bind address (default: 0.0.0.0)",
        "    --port PORT          Listen port (default: 9999)",
        "    --help               Print this help",
        "",
        "COMMANDS (in REPL):",
        "    status               List connected agents",
        "    fire [SECS]          Start stress test (default: 10 seconds)",
        "    terminate            Stop stress test",
        "    shutdown             Disconnect all agents and exit",
        "",
        "EXAMPLE:",
        "    stryke controller --port 9999",
        "",
        "    controller> status",
        "    controller> fire 60      # 60 second stress test",
        "    controller> terminate",
    ];

    for text in HELP_LINES.iter() {
        println!("{}", text);
    }
}