1use std::collections::HashMap;
21use std::io::{Read, Write};
22use std::net::{TcpListener, TcpStream};
23use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
24use std::sync::{Arc, Mutex};
25use std::thread;
26use std::time::Instant;
27
28use crate::agent::{
29 frame_kind, AgentHello, AgentHelloAck, AgentState, FireCommand, WorkloadType,
30 AGENT_PROTO_VERSION,
31};
32
33struct ConnectedAgent {
35 stream: TcpStream,
36 hostname: String,
37 cores: usize,
38 memory_bytes: u64,
39 agent_name: Option<String>,
40 state: AgentState,
41 #[allow(dead_code)]
42 session_id: u64,
43 connected_at: Instant,
44}
45
46pub struct Controller {
48 agents: Arc<Mutex<HashMap<u64, ConnectedAgent>>>,
49 next_session_id: AtomicU64,
50 running: AtomicBool,
51}
52
53impl Default for Controller {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl Controller {
60 pub fn new() -> Self {
61 Self {
62 agents: Arc::new(Mutex::new(HashMap::new())),
63 next_session_id: AtomicU64::new(1),
64 running: AtomicBool::new(true),
65 }
66 }
67
68 fn accept_loop(&self, listener: TcpListener) {
70 for stream in listener.incoming() {
71 if !self.running.load(Ordering::Relaxed) {
72 break;
73 }
74
75 match stream {
76 Ok(mut stream) => {
77 let session_id = self.next_session_id.fetch_add(1, Ordering::Relaxed);
78
79 let (kind, payload) = match read_frame(&mut stream) {
81 Ok(f) => f,
82 Err(e) => {
83 eprintln!("controller: failed to read hello: {}", e);
84 continue;
85 }
86 };
87
88 if kind != frame_kind::AGENT_HELLO {
89 eprintln!("controller: expected AGENT_HELLO, got {}", kind);
90 continue;
91 }
92
93 let hello: AgentHello = match bincode::deserialize(&payload) {
94 Ok(h) => h,
95 Err(e) => {
96 eprintln!("controller: invalid hello: {}", e);
97 continue;
98 }
99 };
100
101 if hello.proto_version != AGENT_PROTO_VERSION {
102 let ack = AgentHelloAck {
103 session_id: 0,
104 accepted: false,
105 message: format!(
106 "protocol version mismatch: got {}, expected {}",
107 hello.proto_version, AGENT_PROTO_VERSION
108 ),
109 };
110 let ack_bytes = bincode::serialize(&ack).unwrap();
111 let _ = write_frame(&mut stream, frame_kind::AGENT_HELLO_ACK, &ack_bytes);
112 continue;
113 }
114
115 let ack = AgentHelloAck {
117 session_id,
118 accepted: true,
119 message: "connected".to_string(),
120 };
121 let ack_bytes = bincode::serialize(&ack).unwrap();
122 if let Err(e) =
123 write_frame(&mut stream, frame_kind::AGENT_HELLO_ACK, &ack_bytes)
124 {
125 eprintln!("controller: failed to send hello ack: {}", e);
126 continue;
127 }
128
129 let name = hello
130 .agent_name
131 .clone()
132 .unwrap_or_else(|| hello.hostname.clone());
133 eprintln!(
134 "[agent connected] {} (cores={}, session={})",
135 name, hello.cores, session_id
136 );
137
138 let agent = ConnectedAgent {
139 stream,
140 hostname: hello.hostname,
141 cores: hello.cores,
142 memory_bytes: hello.memory_bytes,
143 agent_name: hello.agent_name,
144 state: AgentState::Idle,
145 session_id,
146 connected_at: Instant::now(),
147 };
148
149 self.agents.lock().unwrap().insert(session_id, agent);
150 }
151 Err(e) => {
152 if self.running.load(Ordering::Relaxed) {
153 eprintln!("controller: accept error: {}", e);
154 }
155 }
156 }
157 }
158 }
159
160 fn fire_all(&self, duration_secs: f64) {
162 let cmd = FireCommand {
163 workload: WorkloadType::Cpu,
164 duration_secs,
165 intensity: 1.0,
166 };
167 let cmd_bytes = bincode::serialize(&cmd).unwrap();
168
169 let mut agents = self.agents.lock().unwrap();
170 let mut fired = 0;
171
172 for agent in agents.values_mut() {
173 if write_frame(&mut agent.stream, frame_kind::FIRE, &cmd_bytes).is_ok() {
174 agent.state = AgentState::Firing;
175 fired += 1;
176 }
177 }
178
179 eprintln!("[fire] {} agents, duration={}s", fired, duration_secs);
180 }
181
182 fn terminate_all(&self) {
184 let mut agents = self.agents.lock().unwrap();
185 let mut terminated = 0;
186
187 for agent in agents.values_mut() {
188 if write_frame(&mut agent.stream, frame_kind::TERMINATE, &[]).is_ok() {
189 agent.state = AgentState::Terminated;
190 terminated += 1;
191 }
192 }
193
194 eprintln!("[terminate] {} agents", terminated);
195 }
196
197 fn print_status(&self) {
199 let agents = self.agents.lock().unwrap();
200
201 if agents.is_empty() {
202 println!("No agents connected.");
203 return;
204 }
205
206 println!(
207 "{:<20} {:>6} {:>10} {:>12} {:>10}",
208 "AGENT", "CORES", "MEMORY", "STATE", "UPTIME"
209 );
210 println!("{}", "-".repeat(62));
211
212 for agent in agents.values() {
213 let name = agent
214 .agent_name
215 .clone()
216 .unwrap_or_else(|| agent.hostname.clone());
217 let mem_gb = agent.memory_bytes / (1024 * 1024 * 1024);
218 let state = match agent.state {
219 AgentState::Idle => "idle",
220 AgentState::Armed => "armed",
221 AgentState::Firing => "FIRING",
222 AgentState::Terminated => "terminated",
223 };
224 let uptime = agent.connected_at.elapsed().as_secs();
225
226 println!(
227 "{:<20} {:>6} {:>8}GB {:>12} {:>8}s",
228 name, agent.cores, mem_gb, state, uptime
229 );
230 }
231
232 let total_cores: usize = agents.values().map(|a| a.cores).sum();
233 let firing_count = agents
234 .values()
235 .filter(|a| a.state == AgentState::Firing)
236 .count();
237
238 println!();
239 println!(
240 "Total: {} agents, {} cores, {} firing",
241 agents.len(),
242 total_cores,
243 firing_count
244 );
245 }
246
247 fn shutdown_all(&self) {
249 let mut agents = self.agents.lock().unwrap();
250
251 for agent in agents.values_mut() {
252 let _ = write_frame(&mut agent.stream, frame_kind::SHUTDOWN, &[]);
253 }
254
255 agents.clear();
256 self.running.store(false, Ordering::Relaxed);
257 eprintln!("[shutdown] all agents disconnected");
258 }
259
260 fn run_repl(&self) {
262 use std::io::{stdin, BufRead};
263
264 println!("stryke controller v{}", env!("CARGO_PKG_VERSION"));
265 println!("Type 'help' for commands, Ctrl-C to exit\n");
266
267 let stdin = stdin();
268 for line in stdin.lock().lines() {
269 let line = match line {
270 Ok(l) => l,
271 Err(_) => break,
272 };
273
274 let parts: Vec<&str> = line.split_whitespace().collect();
275 if parts.is_empty() {
276 continue;
277 }
278
279 match parts[0] {
280 "status" | "s" => self.print_status(),
281 "fire" | "f" => {
282 let duration = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(10.0);
283 self.fire_all(duration);
284 }
285 "terminate" | "t" | "stop" => self.terminate_all(),
286 "shutdown" | "quit" | "exit" | "q" => {
287 self.shutdown_all();
288 break;
289 }
290 "help" | "h" | "?" => {
291 println!("Commands:");
292 println!(" status (s) List connected agents");
293 println!(" fire [SECS] (f) Start stress test (default: 10s)");
294 println!(" terminate (t) Stop stress test");
295 println!(" shutdown (q) Disconnect all and exit");
296 println!(" help (h) Show this help");
297 }
298 _ => println!("Unknown command: {}. Type 'help' for commands.", parts[0]),
299 }
300 }
301 }
302}
303
/// Reads one length-prefixed frame from `r`.
///
/// Wire layout: an 8-byte little-endian length, then that many bytes, of
/// which the first is the frame kind and the rest is the payload.
///
/// Returns the kind byte and the payload (which may be empty).
///
/// # Errors
/// * `InvalidData` if the declared length is zero, or larger than the
///   64 MiB cap. The length comes straight from the peer, so it is checked
///   before anything is allocated; without the cap a single bogus header
///   could make the process panic or exhaust memory.
/// * Any error from the underlying reader, including `UnexpectedEof` when
///   the stream ends before the declared length has been read.
fn read_frame<R: Read>(r: &mut R) -> std::io::Result<(u8, Vec<u8>)> {
    /// Largest frame (kind byte + payload) accepted from a peer.
    const MAX_FRAME_LEN: u64 = 64 * 1024 * 1024;

    let mut len_buf = [0u8; 8];
    r.read_exact(&mut len_buf)?;
    let len = u64::from_le_bytes(len_buf);

    if len < 1 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "empty frame",
        ));
    }
    if len > MAX_FRAME_LEN {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("frame too large: {} bytes (max {})", len, MAX_FRAME_LEN),
        ));
    }

    // Read the kind byte on its own so the payload buffer can be returned
    // as-is instead of being copied out of a combined buffer.
    let mut kind = [0u8; 1];
    r.read_exact(&mut kind)?;

    // `len - 1` is below MAX_FRAME_LEN, so the cast cannot truncate.
    let mut payload = vec![0u8; (len - 1) as usize];
    r.read_exact(&mut payload)?;

    Ok((kind[0], payload))
}
320
/// Writes one length-prefixed frame to `w` and flushes it.
///
/// Wire layout: an 8-byte little-endian length equal to `1 + payload.len()`,
/// then the `kind` byte, then `payload`.
///
/// The frame is assembled in memory and handed to the writer in a single
/// `write_all`. Callers in this file pass a bare `TcpStream`, where separate
/// writes for header, kind and payload would each become their own send.
///
/// # Errors
/// Returns any error reported by the underlying writer.
fn write_frame<W: Write>(w: &mut W, kind: u8, payload: &[u8]) -> std::io::Result<()> {
    let total_len = 1 + payload.len();

    let mut frame = Vec::with_capacity(8 + total_len);
    frame.extend_from_slice(&(total_len as u64).to_le_bytes());
    frame.push(kind);
    frame.extend_from_slice(payload);

    w.write_all(&frame)?;
    w.flush()
}
329
330pub fn run_controller(bind: &str, port: u16) -> i32 {
332 let addr = format!("{}:{}", bind, port);
333
334 let listener = match TcpListener::bind(&addr) {
335 Ok(l) => l,
336 Err(e) => {
337 eprintln!("controller: cannot bind to {}: {}", addr, e);
338 return 1;
339 }
340 };
341
342 eprintln!("stryke controller listening on {}", addr);
343 eprintln!("Waiting for agents...\n");
344
345 let controller = Arc::new(Controller::new());
346
347 let ctrl_clone = Arc::clone(&controller);
349 let accept_handle = thread::spawn(move || {
350 ctrl_clone.accept_loop(listener);
351 });
352
353 controller.run_repl();
355
356 controller.running.store(false, Ordering::Relaxed);
358 let _ = accept_handle.join();
359
360 0
361}
362
/// Prints the controller's command-line usage and REPL command summary to
/// stdout, one line per table entry.
pub fn print_help() {
    const HELP_LINES: &[&str] = &[
        "stryke controller — Distributed load testing controller",
        "",
        "USAGE:",
        " stryke controller [OPTIONS]",
        "",
        "OPTIONS:",
        " --bind ADDR Bind address (default: 0.0.0.0)",
        " --port PORT Listen port (default: 9999)",
        " --help Print this help",
        "",
        "COMMANDS (in REPL):",
        " status List connected agents",
        " fire [SECS] Start stress test (default: 10 seconds)",
        " terminate Stop stress test",
        " shutdown Disconnect all agents and exit",
        "",
        "EXAMPLE:",
        " stryke controller --port 9999",
        "",
        " controller> status",
        " controller> fire 60 # 60 second stress test",
        " controller> terminate",
    ];

    for text in HELP_LINES.iter() {
        println!("{}", text);
    }
}