Skip to main content

rune_framework/
pilot_client.rs

1use std::path::PathBuf;
2use std::process::Stdio;
3use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7use tokio::net::UnixStream;
8use tokio::process::Command;
9
10use crate::config::ScalePolicy;
11use crate::error::{SdkError, SdkResult};
12
/// Error text that `PilotClient::classify_status` treats as "not ready yet":
/// a non-ok status response carrying exactly this message is classified as
/// `EnsureStatus::Retry` instead of `EnsureStatus::Failed`.
const PILOT_CONNECTING_ERROR: &str = "runtime session not attached";
/// Fallback used by [`pilot_ensure_timeout`] when the environment provides no usable value.
const DEFAULT_PILOT_ENSURE_TIMEOUT: Duration = Duration::from_secs(5);
/// Fallback used by [`pilot_request_timeout`] when the environment provides no usable value.
const DEFAULT_PILOT_REQUEST_TIMEOUT: Duration = Duration::from_secs(5);

/// Reads a whole-second timeout from the environment variable `var`.
///
/// Surrounding whitespace is ignored, so a value such as `"10\n"` (easy to
/// produce from shell scripts or env files) is honoured rather than silently
/// replaced by the default. `fallback` is returned when the variable is unset,
/// is not valid Unicode, or does not parse as a `u64`. A value of `0` is
/// accepted as-is and produces a zero-length timeout.
fn timeout_from_env(var: &str, fallback: Duration) -> Duration {
    std::env::var(var)
        .ok()
        .and_then(|raw| raw.trim().parse::<u64>().ok())
        .map_or(fallback, Duration::from_secs)
}

/// Overall time budget for `PilotClient::ensure`.
///
/// Taken from `RUNE_PILOT_ENSURE_TIMEOUT_SECS` (whole seconds), falling back
/// to [`DEFAULT_PILOT_ENSURE_TIMEOUT`].
fn pilot_ensure_timeout() -> Duration {
    timeout_from_env(
        "RUNE_PILOT_ENSURE_TIMEOUT_SECS",
        DEFAULT_PILOT_ENSURE_TIMEOUT,
    )
}

/// Time limit applied to each individual request sent to the pilot daemon.
///
/// Taken from `RUNE_PILOT_REQUEST_TIMEOUT_SECS` (whole seconds), falling back
/// to [`DEFAULT_PILOT_REQUEST_TIMEOUT`].
fn pilot_request_timeout() -> Duration {
    timeout_from_env(
        "RUNE_PILOT_REQUEST_TIMEOUT_SECS",
        DEFAULT_PILOT_REQUEST_TIMEOUT,
    )
}
32
/// Commands this client sends to the pilot daemon.
///
/// Serialized as an internally tagged object: the variant name, converted to
/// snake_case, is stored under the `"command"` key alongside the variant's
/// own fields (so `Status` is sent with `"command": "status"`).
#[derive(Debug, Serialize)]
#[serde(tag = "command", rename_all = "snake_case")]
enum PilotRequest {
    /// Sent by `PilotClient::register`, which fills `pid` with the current
    /// process id and `group` / `spawn_command` / `shutdown_signal` from the
    /// caller's `ScalePolicy`.
    Register {
        caster_id: String,
        pid: u32,
        group: String,
        spawn_command: String,
        shutdown_signal: String,
    },
    /// Sent by `PilotClient::deregister` for the given caster id.
    Deregister {
        caster_id: String,
    },
    /// Status probe used by `PilotClient::ensure` and its polling loop.
    Status,
    /// Sent by `PilotClient::ensure` when the running pilot reports a
    /// different runtime than the one requested.
    Stop,
}
49
/// Reply decoded from the pilot daemon.
///
/// `ok`, `pilot_id` and `runtime` fall back to their `Default` values
/// (`false` / empty string) when missing, so an error-only payload such as
/// `{"error": "..."}` still deserializes.
#[derive(Debug, Deserialize)]
struct PilotResponse {
    /// Whether the daemon reported success; `false` when the field is absent.
    #[serde(default)]
    ok: bool,
    /// Identifier copied into `PilotClient` once the pilot is ready.
    #[serde(default)]
    pilot_id: String,
    /// Runtime the daemon reports; compared against the normalized requested
    /// runtime in `PilotClient::classify_status`.
    #[serde(default)]
    runtime: String,
    /// Optional error message accompanying a non-ok response.
    error: Option<String>,
}
60
/// Handle to a pilot daemon that has reported itself ready.
///
/// Obtained through `PilotClient::ensure`; it stores only the identifier the
/// daemon returned in its status response.
#[derive(Debug, Clone)]
pub struct PilotClient {
    /// Value of `pilot_id` from the daemon's successful status response.
    pilot_id: String,
}
65
66impl PilotClient {
67    pub async fn ensure(runtime: &str, key: Option<&str>) -> SdkResult<Self> {
68        let normalized = normalize_runtime(runtime);
69        let deadline = tokio::time::Instant::now() + pilot_ensure_timeout();
70        if let Ok(response) = send_request(&PilotRequest::Status).await {
71            match Self::classify_status(response, &normalized) {
72                EnsureStatus::Ready(client) => return Ok(client),
73                EnsureStatus::Retry => {
74                    return Self::wait_until_ready(&normalized, deadline, Some(runtime), key).await;
75                }
76                EnsureStatus::Mismatch => {
77                    // Existing pilot is bound to a different runtime โ€” stop it first.
78                    let _ = send_request(&PilotRequest::Stop).await;
79                }
80                EnsureStatus::Failed(error) => return Err(SdkError::Other(error)),
81            }
82        }
83
84        start_pilot(runtime, key).await?;
85        Self::wait_until_ready(&normalized, deadline, Some(runtime), key).await
86    }
87
88    /// Poll until the pilot reports ready. When `start_runtime`/`start_key`
89    /// are provided, re-attempt `start_pilot` on connection failure so that
90    /// a slow predecessor release doesn't doom the single initial spawn.
91    async fn wait_until_ready(
92        normalized: &str,
93        deadline: tokio::time::Instant,
94        start_runtime: Option<&str>,
95        start_key: Option<&str>,
96    ) -> SdkResult<Self> {
97        let mut last_start = tokio::time::Instant::now();
98        loop {
99            match send_request(&PilotRequest::Status).await {
100                Ok(response) => match Self::classify_status(response, normalized) {
101                    EnsureStatus::Ready(client) => return Ok(client),
102                    EnsureStatus::Retry | EnsureStatus::Mismatch => {}
103                    EnsureStatus::Failed(error) => return Err(SdkError::Other(error)),
104                },
105                Err(_) => {
106                    // Connection failed โ€” re-attempt start if enough time has passed.
107                    if let Some(rt) = start_runtime {
108                        if last_start.elapsed() >= Duration::from_secs(1) {
109                            let _ = start_pilot(rt, start_key).await;
110                            last_start = tokio::time::Instant::now();
111                        }
112                    }
113                }
114            }
115            if tokio::time::Instant::now() >= deadline {
116                break;
117            }
118            tokio::time::sleep(Duration::from_millis(100)).await;
119        }
120
121        Err(SdkError::Other("pilot did not become ready".into()))
122    }
123
124    pub fn pilot_id(&self) -> &str {
125        &self.pilot_id
126    }
127
128    pub async fn register(&self, caster_id: &str, policy: &ScalePolicy) -> SdkResult<()> {
129        let response = send_request(&PilotRequest::Register {
130            caster_id: caster_id.to_string(),
131            pid: std::process::id(),
132            group: policy.group.clone(),
133            spawn_command: policy.spawn_command.clone(),
134            shutdown_signal: policy.shutdown_signal.clone(),
135        })
136        .await?;
137        Self::ensure_ok(response)
138    }
139
140    pub async fn deregister(&self, caster_id: &str) -> SdkResult<()> {
141        let response = send_request(&PilotRequest::Deregister {
142            caster_id: caster_id.to_string(),
143        })
144        .await?;
145        Self::ensure_ok(response)
146    }
147
148    fn ensure_ok(response: PilotResponse) -> SdkResult<()> {
149        if response.ok {
150            Ok(())
151        } else {
152            Err(SdkError::Other(
153                response
154                    .error
155                    .unwrap_or_else(|| "pilot request failed".into()),
156            ))
157        }
158    }
159
160    fn classify_status(response: PilotResponse, normalized: &str) -> EnsureStatus {
161        if response.runtime != normalized {
162            return EnsureStatus::Mismatch;
163        }
164        if response.ok {
165            return EnsureStatus::Ready(Self {
166                pilot_id: response.pilot_id,
167            });
168        }
169        match response.error {
170            Some(error) if error == PILOT_CONNECTING_ERROR || error.is_empty() => {
171                EnsureStatus::Retry
172            }
173            Some(error) => EnsureStatus::Failed(error),
174            None => EnsureStatus::Retry,
175        }
176    }
177}
178
/// Result of classifying a pilot status response
/// (see `PilotClient::classify_status`).
enum EnsureStatus {
    /// The pilot is bound to the requested runtime and reported `ok`.
    Ready(PilotClient),
    /// The pilot is bound to the requested runtime but is not ready yet
    /// (no error, an empty error, or the "still attaching" error).
    Retry,
    /// The reported runtime differs from the requested one.
    Mismatch,
    /// The pilot reported a non-retryable error; carries its message.
    Failed(String),
}
185
186async fn send_request(request: &PilotRequest) -> SdkResult<PilotResponse> {
187    tokio::time::timeout(pilot_request_timeout(), send_request_inner(request))
188        .await
189        .map_err(|_| SdkError::Other("pilot request timed out".into()))?
190}
191
192async fn send_request_inner(request: &PilotRequest) -> SdkResult<PilotResponse> {
193    let socket_path = socket_path()?;
194    let mut stream = UnixStream::connect(&socket_path)
195        .await
196        .map_err(|err| SdkError::Other(format!("failed to connect to pilot: {err}")))?;
197    let payload = serde_json::to_vec(request)
198        .map_err(|err| SdkError::Other(format!("failed to encode pilot request: {err}")))?;
199    stream
200        .write_all(&payload)
201        .await
202        .map_err(|err| SdkError::Other(format!("failed to write pilot request: {err}")))?;
203    stream
204        .shutdown()
205        .await
206        .map_err(|err| SdkError::Other(format!("failed to flush pilot request: {err}")))?;
207    const MAX_RESPONSE_SIZE: u64 = 256 * 1024; // 256KB โ€” symmetric with daemon's 64KB request limit
208    let mut response = Vec::new();
209    stream
210        .take(MAX_RESPONSE_SIZE)
211        .read_to_end(&mut response)
212        .await
213        .map_err(|err| SdkError::Other(format!("failed to read pilot response: {err}")))?;
214    serde_json::from_slice(&response)
215        .map_err(|err| SdkError::Other(format!("failed to decode pilot response: {err}")))
216}
217
218async fn start_pilot(runtime: &str, key: Option<&str>) -> SdkResult<()> {
219    let mut command = Command::new(find_rune_binary()?);
220    command
221        .arg("pilot")
222        .arg("daemon")
223        .arg("--runtime")
224        .arg(normalize_runtime(runtime))
225        .stdin(Stdio::null())
226        .stdout(Stdio::null())
227        .stderr(Stdio::null());
228    if let Some(key) = key {
229        command.env("RUNE_KEY", key);
230    }
231    #[cfg(unix)]
232    // SAFETY: `libc::setsid()` is async-signal-safe per POSIX.1-2017 ยง2.4.3,
233    // which is the requirement for functions called inside `pre_exec`.
234    // It detaches the child from the parent's session so the pilot daemon
235    // survives after the SDK process exits.
236    unsafe {
237        command.pre_exec(|| {
238            libc::setsid();
239            Ok(())
240        });
241    }
242    command
243        .spawn()
244        .map_err(|err| SdkError::Other(format!("failed to start pilot daemon: {err}")))?;
245    Ok(())
246}
247
/// Canonical form of a runtime address used for comparisons and for the
/// daemon's `--runtime` argument: surrounding whitespace is removed first,
/// then every trailing `/`.
fn normalize_runtime(runtime: &str) -> String {
    let without_whitespace = runtime.trim();
    let without_slashes = without_whitespace.trim_end_matches('/');
    String::from(without_slashes)
}
251
252fn socket_path() -> SdkResult<PathBuf> {
253    Ok(home_dir()?.join(".rune").join("pilot.sock"))
254}
255
256fn find_rune_binary() -> SdkResult<PathBuf> {
257    if let Ok(path) = std::env::var("RUNE_BIN") {
258        return Ok(PathBuf::from(path));
259    }
260
261    if let Some(paths) = std::env::var_os("PATH") {
262        for dir in std::env::split_paths(&paths) {
263            let candidate = dir.join("rune");
264            if candidate.is_file() {
265                return Ok(candidate);
266            }
267        }
268    }
269
270    Err(SdkError::Other(
271        "failed to locate rune binary; set RUNE_BIN or add rune to PATH".into(),
272    ))
273}
274
275fn home_dir() -> SdkResult<PathBuf> {
276    std::env::var_os("HOME")
277        .map(PathBuf::from)
278        .ok_or_else(|| SdkError::Other("failed to determine HOME".into()))
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::VecDeque;
    use std::ffi::OsString;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::{Mutex, OnceLock};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokio::net::UnixListener;

    /// Serializes tests that redirect the process-wide `HOME` variable;
    /// `HomeGuard` holds this lock for its whole lifetime.
    static ENV_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    /// Test fixture that points `HOME` at a fresh directory (containing a
    /// `.rune` subdirectory) and undoes the change when dropped.
    struct HomeGuard {
        /// Value of `HOME` before the guard was created, if any.
        previous: Option<OsString>,
        /// Temporary directory installed as `HOME`; removed on drop.
        root: PathBuf,
        /// Keeps `ENV_LOCK` held while `HOME` is redirected.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl HomeGuard {
        /// Takes `ENV_LOCK` (recovering from poisoning), creates a directory
        /// under `/tmp` named from the process id and the current time in
        /// milliseconds, and installs it as `HOME`.
        fn set() -> Self {
            let lock = ENV_LOCK
                .get_or_init(|| Mutex::new(()))
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let unique = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis();
            let root = PathBuf::from(format!("/tmp/rpc-{}-{unique}", std::process::id()));
            fs::create_dir_all(root.join(".rune")).unwrap();
            let previous = std::env::var_os("HOME");
            std::env::set_var("HOME", &root);
            Self {
                previous,
                root,
                _lock: lock,
            }
        }
    }

    impl Drop for HomeGuard {
        /// Restores (or removes) `HOME` and deletes the temporary directory,
        /// ignoring deletion errors.
        fn drop(&mut self) {
            if let Some(previous) = self.previous.take() {
                std::env::set_var("HOME", previous);
            } else {
                std::env::remove_var("HOME");
            }
            let _ = fs::remove_dir_all(&self.root);
        }
    }

    /// Fake pilot daemon: binds a Unix listener at `socket` and answers one
    /// connection per queued response, in order. Every request must be a
    /// `status` command. The spawned task finishes once the queue is empty.
    async fn spawn_status_server(
        socket: &Path,
        responses: Vec<serde_json::Value>,
    ) -> tokio::task::JoinHandle<()> {
        // Remove any stale socket file so that bind() does not fail.
        let _ = fs::remove_file(socket);
        let listener = UnixListener::bind(socket).unwrap();
        let responses = Mutex::new(VecDeque::from(responses));
        tokio::spawn(async move {
            loop {
                let Some(response) = responses
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner())
                    .pop_front()
                else {
                    break;
                };
                let (mut stream, _) = listener.accept().await.unwrap();
                let mut request = Vec::new();
                // The client shuts down its write half after the request, so
                // read_to_end returns once the whole request has arrived.
                stream.read_to_end(&mut request).await.unwrap();
                let payload: serde_json::Value = serde_json::from_slice(&request).unwrap();
                assert_eq!(payload["command"], "status");
                stream
                    .write_all(&serde_json::to_vec(&response).unwrap())
                    .await
                    .unwrap();
            }
        })
    }

    /// `ensure` must keep polling while the pilot answers with the
    /// "runtime session not attached" error and succeed once it reports ok.
    #[tokio::test]
    async fn test_fix_ensure_waits_for_matching_runtime_to_become_ready() {
        let _home = HomeGuard::set();
        let socket = socket_path().unwrap();
        let server = spawn_status_server(
            &socket,
            vec![
                json!({
                    "ok": false,
                    "pilot_id": "pilot-1",
                    "runtime": "127.0.0.1:50051",
                    "error": "runtime session not attached"
                }),
                json!({
                    "ok": true,
                    "pilot_id": "pilot-1",
                    "runtime": "127.0.0.1:50051",
                    "error": null
                }),
            ],
        )
        .await;

        let client = PilotClient::ensure("127.0.0.1:50051", None)
            .await
            .expect("ensure should keep polling while pilot is still connecting");
        assert_eq!(client.pilot_id(), "pilot-1");

        server.await.unwrap();
    }

    /// Regression: PilotResponse must deserialize even when `ok` and `pilot_id`
    /// are absent (e.g. error-only responses from pilot daemon).
    #[test]
    fn test_fix_pilot_response_deserialize_without_pilot_id() {
        let json = r#"{"ok": false, "error": "connection refused"}"#;
        let resp: PilotResponse =
            serde_json::from_str(json).expect("should deserialize error response missing pilot_id");
        assert!(!resp.ok);
        assert_eq!(resp.pilot_id, "");
        assert_eq!(resp.error.as_deref(), Some("connection refused"));

        // Minimal error-only payload (no ok, no pilot_id, no runtime)
        let json_minimal = r#"{"error": "socket not found"}"#;
        let resp2: PilotResponse = serde_json::from_str(json_minimal)
            .expect("should deserialize minimal error-only response");
        assert!(!resp2.ok);
        assert_eq!(resp2.pilot_id, "");
        assert_eq!(resp2.runtime, "");
        assert_eq!(resp2.error.as_deref(), Some("socket not found"));
    }

    /// Regression: empty-string error must retry (aligned with Python/TS SDKs).
    /// Rust previously treated Some("") as Failed(""), while Python/TS treated
    /// falsy error as retry.
    #[test]
    fn test_fix_classify_status_empty_error_retries() {
        let response = PilotResponse {
            ok: false,
            pilot_id: "pilot-1".into(),
            runtime: "127.0.0.1:50051".into(),
            error: Some("".into()),
        };
        let result = PilotClient::classify_status(response, "127.0.0.1:50051");
        assert!(
            matches!(result, EnsureStatus::Retry),
            "empty error string should classify as Retry, not Failed"
        );
    }
}