Skip to main content

mcp_compressor_core/
oauth.rs

1//! OAuth helpers for remote MCP backends.
2//!
3//! The runtime delegates OAuth protocol details to `rmcp`. This module only
4//! provides compressor-specific storage and local callback plumbing.
5
6use std::env;
7use std::fs;
8use std::time::{SystemTime, UNIX_EPOCH};
9use std::io::{Read, Write};
10use std::net::{SocketAddr, TcpListener};
11use std::path::{Path, PathBuf};
12use std::time::Duration;
13
14use serde::{Deserialize, Serialize};
15
16use rmcp::transport::auth::{
17    AuthError, CredentialStore, StateStore, StoredAuthorizationState, StoredCredentials,
18};
19
/// Name of the directory, below `<config dir>/mcp-compressor`, that holds all OAuth stores.
const OAUTH_TOKEN_DIR_NAME: &str = "oauth-tokens-rust";
21
22#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
23pub struct OAuthStoreIndexEntry {
24    pub name: String,
25    pub uri: String,
26    pub store_dir: String,
27}
28
/// OAuth credential store persisted as a single JSON file.
#[derive(Clone, Debug)]
pub struct FileCredentialStore {
    /// Location of the JSON credentials file.
    path: PathBuf,
}

impl FileCredentialStore {
    /// Creates a store backed by the file at `path`. Nothing is read or written yet.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path = path.into();
        Self { path }
    }

    /// Returns the location of the backing credentials file.
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }
}
44
45#[async_trait::async_trait]
46impl CredentialStore for FileCredentialStore {
47    async fn load(&self) -> Result<Option<StoredCredentials>, AuthError> {
48        let Some(contents) = read_optional(&self.path)? else {
49            return Ok(None);
50        };
51        serde_json::from_str(&contents).map(Some).map_err(|error| {
52            AuthError::InternalError(format!("failed to parse OAuth credentials: {error}"))
53        })
54    }
55
56    async fn save(&self, credentials: StoredCredentials) -> Result<(), AuthError> {
57        write_json(&self.path, &credentials)
58    }
59
60    async fn clear(&self) -> Result<(), AuthError> {
61        remove_optional(&self.path)
62    }
63}
64
65/// File-backed OAuth authorization-state store.
66#[derive(Debug, Clone)]
67pub struct FileStateStore {
68    dir: PathBuf,
69}
70
71impl FileStateStore {
72    pub fn new(dir: impl Into<PathBuf>) -> Self {
73        Self { dir: dir.into() }
74    }
75
76    pub fn dir(&self) -> &Path {
77        &self.dir
78    }
79
80    fn state_path(&self, csrf_token: &str) -> PathBuf {
81        self.dir
82            .join(format!("{}.json", sanitize_file_component(csrf_token)))
83    }
84}
85
86#[async_trait::async_trait]
87impl StateStore for FileStateStore {
88    async fn save(
89        &self,
90        csrf_token: &str,
91        state: StoredAuthorizationState,
92    ) -> Result<(), AuthError> {
93        write_json(&self.state_path(csrf_token), &state)
94    }
95
96    async fn load(&self, csrf_token: &str) -> Result<Option<StoredAuthorizationState>, AuthError> {
97        let Some(contents) = read_optional(&self.state_path(csrf_token))? else {
98            return Ok(None);
99        };
100        serde_json::from_str(&contents).map(Some).map_err(|error| {
101            AuthError::InternalError(format!("failed to parse OAuth state: {error}"))
102        })
103    }
104
105    async fn delete(&self, csrf_token: &str) -> Result<(), AuthError> {
106        remove_optional(&self.state_path(csrf_token))
107    }
108}
109
110fn read_optional(path: &Path) -> Result<Option<String>, AuthError> {
111    match fs::read_to_string(path) {
112        Ok(contents) => Ok(Some(contents)),
113        Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(None),
114        Err(error) => Err(AuthError::InternalError(format!(
115            "failed to read OAuth store {}: {error}",
116            path.display()
117        ))),
118    }
119}
120
121fn write_json<T: serde::Serialize>(path: &Path, value: &T) -> Result<(), AuthError> {
122    if let Some(parent) = path.parent() {
123        fs::create_dir_all(parent).map_err(|error| {
124            AuthError::InternalError(format!(
125                "failed to create OAuth store directory {}: {error}",
126                parent.display()
127            ))
128        })?;
129    }
130    let json = serde_json::to_string_pretty(value).map_err(|error| {
131        AuthError::InternalError(format!("failed to serialize OAuth store: {error}"))
132    })?;
133    atomic_write(path, json.as_bytes()).map_err(|error| {
134        AuthError::InternalError(format!(
135            "failed to write OAuth store {}: {error}",
136            path.display()
137        ))
138    })
139}
140
/// Replaces `path` with `contents` so readers never observe a partially written file.
///
/// The bytes are first written to a uniquely named hidden temporary file (`.<name>.….tmp`) in
/// the same directory. That file is synced to disk and then renamed over `path`. On any failure
/// the temporary file is removed. Missing parent directories are created first.
///
/// On Unix the file is created with mode `0o600`, because these stores hold OAuth tokens.
///
/// # Errors
/// Returns the first I/O error from creating directories, writing, syncing, or renaming.
fn atomic_write(path: &Path, contents: &[u8]) -> Result<(), std::io::Error> {
    // Distinguishes writes issued by this process within the same clock tick.
    static WRITE_SEQUENCE: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }
    let parent = path.parent().unwrap_or_else(|| Path::new("."));
    let file_name = path
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("store.json");
    let nonce = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|duration| duration.as_nanos())
        .unwrap_or_default();
    let sequence = WRITE_SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    // The process id keeps separate processes sharing one store directory from colliding.
    let tmp_path = parent.join(format!(
        ".{file_name}.{}.{nonce}.{sequence}.tmp",
        std::process::id()
    ));

    // If creation fails nothing was created, so there is nothing of ours to clean up.
    let tmp_file = create_private_file(&tmp_path)?;
    let outcome = write_and_sync(tmp_file, contents).and_then(|()| fs::rename(&tmp_path, path));
    if outcome.is_err() {
        // Best effort: never leave our own temporary file behind.
        let _ = fs::remove_file(&tmp_path);
    }
    outcome
}

/// Creates a new file at `path`, failing if it already exists. On Unix it is owner-only (`0o600`).
fn create_private_file(path: &Path) -> Result<fs::File, std::io::Error> {
    let mut options = fs::OpenOptions::new();
    options.write(true).create_new(true);
    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;
        options.mode(0o600);
    }
    options.open(path)
}

/// Writes all of `contents` to `file` and flushes it to stable storage. The file closes on return.
fn write_and_sync(mut file: fs::File, contents: &[u8]) -> Result<(), std::io::Error> {
    file.write_all(contents)?;
    file.sync_all()
}
161
162fn remove_optional(path: &Path) -> Result<(), AuthError> {
163    match fs::remove_file(path) {
164        Ok(()) => Ok(()),
165        Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(()),
166        Err(error) => Err(AuthError::InternalError(format!(
167            "failed to remove OAuth store {}: {error}",
168            path.display()
169        ))),
170    }
171}
172
173/// Local OAuth callback listener bound to loopback.
174#[derive(Debug)]
175pub struct OAuthCallbackListener {
176    listener: TcpListener,
177    redirect_uri: String,
178}
179
180impl OAuthCallbackListener {
181    pub fn bind() -> Result<Self, std::io::Error> {
182        let listener = TcpListener::bind("127.0.0.1:0")?;
183        let addr = listener.local_addr()?;
184        Ok(Self {
185            listener,
186            redirect_uri: format!("http://{addr}/callback"),
187        })
188    }
189
190    pub fn redirect_uri(&self) -> &str {
191        &self.redirect_uri
192    }
193
194    pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
195        self.listener.local_addr()
196    }
197
198    pub fn wait_for_callback(self) -> Result<OAuthCallback, std::io::Error> {
199        let (mut stream, _) = self.listener.accept()?;
200        stream.set_read_timeout(Some(Duration::from_secs(30)))?;
201        let mut request = [0_u8; 8192];
202        let bytes = stream.read(&mut request)?;
203        let request = String::from_utf8_lossy(&request[..bytes]);
204        match parse_callback_request(&request) {
205            OAuthCallbackResult::Success(callback) => {
206                write_callback_response(
207                    &mut stream,
208                    200,
209                    "OAuth complete. You can close this tab and return to mcp-compressor.",
210                )?;
211                Ok(callback)
212            }
213            OAuthCallbackResult::ProviderError { error, description } => {
214                write_callback_response(
215                    &mut stream,
216                    400,
217                    "OAuth authorization failed. You can close this tab and return to mcp-compressor.",
218                )?;
219                Err(std::io::Error::new(
220                    std::io::ErrorKind::PermissionDenied,
221                    format_callback_provider_error(&error, description.as_deref()),
222                ))
223            }
224            OAuthCallbackResult::Malformed(reason) => {
225                write_callback_response(
226                    &mut stream,
227                    400,
228                    "OAuth callback was missing required parameters. You can close this tab.",
229                )?;
230                Err(std::io::Error::new(std::io::ErrorKind::InvalidData, reason))
231            }
232        }
233    }
234}
235
236pub fn open_authorization_url(url: &str) -> Result<BrowserOpenStatus, std::io::Error> {
237    if browser_open_disabled() {
238        return Ok(BrowserOpenStatus::Disabled);
239    }
240    open::that(url)
241        .map(|_| BrowserOpenStatus::Opened)
242        .map_err(std::io::Error::other)
243}
244
/// Successful outcome of [`open_authorization_url`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BrowserOpenStatus {
    /// The URL was handed to the system browser launcher.
    Opened,
    /// Launching was skipped because of the `MCP_COMPRESSOR_NO_BROWSER` environment variable.
    Disabled,
}
250
/// Reports whether `MCP_COMPRESSOR_NO_BROWSER` asks us not to launch a browser.
///
/// The value is trimmed and compared case-insensitively against `1`, `true`, and `yes`. This
/// accepts `True`/`Yes` alongside the previously recognised `true`/`TRUE`/`yes`/`YES`. An unset,
/// non-Unicode, or any other value leaves browser opening enabled.
fn browser_open_disabled() -> bool {
    const TRUTHY: [&str; 3] = ["1", "true", "yes"];
    env::var("MCP_COMPRESSOR_NO_BROWSER")
        .map(|value| {
            let value = value.trim();
            TRUTHY.iter().any(|flag| value.eq_ignore_ascii_case(flag))
        })
        .unwrap_or(false)
}
256
/// Authorization `code` and CSRF `state` taken from a successful redirect request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OAuthCallback {
    /// Value of the `code` query parameter (percent-decoded by `parse_callback_request`).
    pub code: String,
    /// Value of the `state` query parameter (percent-decoded by `parse_callback_request`).
    pub state: String,
}

/// Outcome of interpreting the raw request received by the loopback listener.
#[derive(Clone, Debug, Eq, PartialEq)]
enum OAuthCallbackResult {
    /// Both `code` and `state` were present and non-empty.
    Success(OAuthCallback),
    /// The provider redirected with an `error` parameter and an optional `error_description`.
    ProviderError {
        error: String,
        description: Option<String>,
    },
    /// The request could not be interpreted; the payload says why.
    Malformed(String),
}
272
273fn parse_callback_request(request: &str) -> OAuthCallbackResult {
274    let Some(first_line) = request.lines().next() else {
275        return OAuthCallbackResult::Malformed("OAuth callback request was empty".to_string());
276    };
277    let Some(path) = first_line.split_whitespace().nth(1) else {
278        return OAuthCallbackResult::Malformed(
279            "OAuth callback request line was invalid".to_string(),
280        );
281    };
282    let Some(query) = path.split_once('?').map(|(_, query)| query) else {
283        return OAuthCallbackResult::Malformed(
284            "OAuth callback query string was missing".to_string(),
285        );
286    };
287    let mut code = None;
288    let mut state = None;
289    let mut error = None;
290    let mut error_description = None;
291    for pair in query.split('&') {
292        let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
293        match key {
294            "code" => code = Some(percent_decode(value)),
295            "state" => state = Some(percent_decode(value)),
296            "error" => error = Some(percent_decode(value)),
297            "error_description" => error_description = Some(percent_decode(value)),
298            _ => {}
299        }
300    }
301    if let Some(error) = error {
302        return OAuthCallbackResult::ProviderError {
303            error,
304            description: error_description,
305        };
306    }
307    match (code, state) {
308        (Some(code), Some(state)) if !code.is_empty() && !state.is_empty() => {
309            OAuthCallbackResult::Success(OAuthCallback { code, state })
310        }
311        _ => OAuthCallbackResult::Malformed(
312            "OAuth callback was missing non-empty code or state".to_string(),
313        ),
314    }
315}
316
317fn write_callback_response(
318    stream: &mut impl Write,
319    status: u16,
320    body: &str,
321) -> Result<(), std::io::Error> {
322    let status_text = match status {
323        200 => "OK",
324        400 => "Bad Request",
325        _ => "OK",
326    };
327    let html = callback_html(status, body);
328    let response = format!(
329        "HTTP/1.1 {status} {status_text}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
330        html.len(),
331        html
332    );
333    stream.write_all(response.as_bytes())
334}
335
336fn callback_html(status: u16, body: &str) -> String {
337    let success = status == 200;
338    let title = if success {
339        "Authorization complete"
340    } else {
341        "Authorization failed"
342    };
343    let message = if success {
344        "The requesting app has been successfully authorized. You may close this window."
345            .to_string()
346    } else {
347        escape_html(body)
348    };
349    let accent = if success { "#22A06B" } else { "#AE2E24" };
350    let icon = if success { "&#10003;" } else { "!" };
351    format!(
352        r#"<!doctype html>
353<html lang="en">
354<head>
355  <meta charset="utf-8">
356  <meta name="viewport" content="width=device-width, initial-scale=1">
357  <title>{title}</title>
358  <style>
359    :root {{
360      --text: #172B4D;
361      --text-subtle: #44546F;
362      --surface: #FFFFFF;
363      --bg: #F4F5F7;
364      --border: rgba(9, 30, 66, 0.13);
365      --shadow: 0 10px 10px rgba(0, 0, 0, 0.1);
366      --accent: {accent};
367      --font: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, "Fira Sans", "Droid Sans", "Helvetica Neue", sans-serif;
368    }}
369    * {{ box-sizing: border-box; }}
370    html, body {{ min-height: 100%; margin: 0; }}
371    body {{
372      min-height: 100vh;
373      background: var(--bg);
374      color: var(--text);
375      font-family: var(--font);
376      font-size: 14px;
377      line-height: 1.5;
378      display: flex;
379      align-items: center;
380      justify-content: center;
381      padding: 48px 24px;
382    }}
383    .card {{
384      background: var(--surface);
385      border-radius: 8px;
386      box-shadow: var(--shadow);
387      width: min(400px, 100%);
388      padding: 48px 40px;
389      text-align: center;
390    }}
391    .icon {{
392      width: 64px;
393      height: 64px;
394      border-radius: 50%;
395      background: var(--accent);
396      color: #fff;
397      font-size: 32px;
398      line-height: 64px;
399      margin: 0 auto 24px;
400    }}
401    h1 {{
402      margin: 0 0 12px;
403      font-size: 20px;
404      font-weight: 600;
405      color: var(--text);
406    }}
407    p {{
408      margin: 0;
409      color: var(--text-subtle);
410    }}
411  </style>
412</head>
413<body>
414  <main class="card" role="main">
415    <div class="icon" aria-hidden="true">{icon}</div>
416    <h1>{title}</h1>
417    <p>{message}</p>
418  </main>
419</body>
420</html>"#,
421    )
422}
423
/// Escapes the five HTML-significant characters (`&`, `<`, `>`, `"`, `'`) in `value`.
///
/// The escaped text is safe to embed in element content and in quoted attribute values.
fn escape_html(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
432
/// Builds the error message for a provider-reported OAuth failure.
///
/// `description` is appended after a colon only when it is present and non-empty.
fn format_callback_provider_error(error: &str, description: Option<&str>) -> String {
    let mut message = format!("OAuth provider returned {error}");
    if let Some(detail) = description.filter(|detail| !detail.is_empty()) {
        message.push_str(": ");
        message.push_str(detail);
    }
    message
}
441
442fn percent_decode(value: &str) -> String {
443    let mut output = Vec::with_capacity(value.len());
444    let bytes = value.as_bytes();
445    let mut index = 0;
446    while index < bytes.len() {
447        match bytes[index] {
448            b'+' => {
449                output.push(b' ');
450                index += 1;
451            }
452            b'%' if index + 2 < bytes.len() => {
453                if let (Some(high), Some(low)) =
454                    (hex_value(bytes[index + 1]), hex_value(bytes[index + 2]))
455                {
456                    output.push((high << 4) | low);
457                    index += 3;
458                } else {
459                    output.push(bytes[index]);
460                    index += 1;
461                }
462            }
463            byte => {
464                output.push(byte);
465                index += 1;
466            }
467        }
468    }
469    String::from_utf8_lossy(&output).into_owned()
470}
471
/// Returns the numeric value (0–15) of an ASCII hex digit, or `None` for any other byte.
fn hex_value(byte: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly 0-9, a-f and A-F; the result is < 16, so the cast is lossless.
    char::from(byte).to_digit(16).map(|digit| digit as u8)
}
480
481pub fn oauth_store_root() -> PathBuf {
482    dirs::config_dir()
483        .unwrap_or_else(|| PathBuf::from("."))
484        .join("mcp-compressor")
485        .join(OAUTH_TOKEN_DIR_NAME)
486}
487
488pub fn oauth_store_dir(uri: &str, name: &str) -> PathBuf {
489    oauth_store_root().join(sanitize_file_component(&format!("{name}-{uri}")))
490}
491
492pub fn remember_oauth_store(uri: &str, name: &str, store_dir: &Path) -> Result<(), std::io::Error> {
493    let root = oauth_store_root();
494    fs::create_dir_all(&root)?;
495    let index_path = root.join("index.json");
496    let mut entries = read_oauth_store_index_from(&index_path)?;
497    let store_dir = store_dir.to_string_lossy().into_owned();
498    entries.retain(|entry| !(entry.name == name && entry.uri == uri));
499    entries.push(OAuthStoreIndexEntry {
500        name: name.to_string(),
501        uri: uri.to_string(),
502        store_dir,
503    });
504    entries.sort_by(|left, right| left.name.cmp(&right.name).then(left.uri.cmp(&right.uri)));
505    write_oauth_index(&index_path, &entries)
506}
507
508pub fn list_oauth_stores() -> Result<Vec<OAuthStoreIndexEntry>, std::io::Error> {
509    read_oauth_store_index_from(&oauth_store_root().join("index.json"))
510}
511
512pub fn clear_oauth_store(target: Option<&str>) -> Result<Vec<PathBuf>, std::io::Error> {
513    let root = oauth_store_root();
514    let index_path = root.join("index.json");
515    if !root.exists() {
516        return Ok(Vec::new());
517    }
518    let entries = read_oauth_store_index_from(&index_path)?;
519    let mut removed = Vec::new();
520    if let Some(target) = target {
521        for entry in entries
522            .iter()
523            .filter(|entry| entry.name == target || entry.uri == target)
524        {
525            let path = PathBuf::from(&entry.store_dir);
526            if path.exists() {
527                fs::remove_dir_all(&path)?;
528                removed.push(path);
529            }
530        }
531        let remaining = entries
532            .into_iter()
533            .filter(|entry| entry.name != target && entry.uri != target)
534            .collect::<Vec<_>>();
535        write_oauth_index(&index_path, &remaining)?;
536    } else {
537        fs::remove_dir_all(&root)?;
538        removed.push(root);
539    }
540    Ok(removed)
541}
542
543fn read_oauth_store_index_from(path: &Path) -> Result<Vec<OAuthStoreIndexEntry>, std::io::Error> {
544    match fs::read_to_string(path) {
545        Ok(contents) => serde_json::from_str(&contents).map_err(|error| {
546            std::io::Error::new(
547                std::io::ErrorKind::InvalidData,
548                format!("failed to parse OAuth store index {}: {error}", path.display()),
549            )
550        }),
551        Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(Vec::new()),
552        Err(error) => Err(error),
553    }
554}
555
556fn write_oauth_index(
557    path: &Path,
558    entries: &[OAuthStoreIndexEntry],
559) -> Result<(), std::io::Error> {
560    let json = serde_json::to_string_pretty(entries).map_err(|error| {
561        std::io::Error::new(
562            std::io::ErrorKind::InvalidData,
563            format!("failed to serialize OAuth store index {}: {error}", path.display()),
564        )
565    })?;
566    atomic_write(path, json.as_bytes())
567}
568
/// Maps `value` to a string usable as a single file-name component.
///
/// ASCII letters, digits, `-` and `_` are kept. Every other character, including each
/// non-ASCII character, becomes one `_`. An empty input yields `"state"`. The mapping is not
/// injective (`a/b` and `a.b` both become `a_b`).
fn sanitize_file_component(value: &str) -> String {
    if value.is_empty() {
        return "state".to_string();
    }
    value
        .chars()
        .map(|ch| match ch {
            'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' => ch,
            _ => '_',
        })
        .collect()
}
586
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn file_credential_store_missing_loads_none_and_clear_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileCredentialStore::new(dir.path().join("credentials.json"));

        assert!(store.load().await.unwrap().is_none());
        store.clear().await.unwrap();
    }

    #[tokio::test]
    async fn file_state_store_missing_loads_none_and_delete_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let store = FileStateStore::new(dir.path().join("state"));

        assert!(store.load("missing-token").await.unwrap().is_none());
        store.delete("missing-token").await.unwrap();
    }

    #[test]
    fn oauth_index_round_trips_entries() {
        let root = tempfile::tempdir().unwrap();
        let index_path = root.path().join("index.json");
        let entry = OAuthStoreIndexEntry {
            name: "alpha".to_string(),
            uri: "https://example.test/mcp".to_string(),
            store_dir: root.path().join("store").to_string_lossy().into_owned(),
        };
        std::fs::write(
            &index_path,
            serde_json::to_string_pretty(std::slice::from_ref(&entry)).unwrap(),
        )
        .unwrap();

        let entries = read_oauth_store_index_from(&index_path).unwrap();
        assert_eq!(entries, vec![entry]);
    }

    #[test]
    fn oauth_index_missing_file_reads_as_empty() {
        let root = tempfile::tempdir().unwrap();
        let entries = read_oauth_store_index_from(&root.path().join("index.json")).unwrap();
        assert!(entries.is_empty());
    }

    #[test]
    fn oauth_index_corruption_is_reported() {
        let root = tempfile::tempdir().unwrap();
        let index_path = root.path().join("index.json");
        std::fs::write(&index_path, "not json").unwrap();

        let error = read_oauth_store_index_from(&index_path).unwrap_err();
        assert_eq!(error.kind(), std::io::ErrorKind::InvalidData);
        assert!(error.to_string().contains("failed to parse OAuth store index"));
    }

    #[test]
    fn oauth_index_writes_are_atomic_and_do_not_leave_temp_files() {
        let root = tempfile::tempdir().unwrap();
        let index_path = root.path().join("index.json");
        let entries = vec![OAuthStoreIndexEntry {
            name: "alpha".to_string(),
            uri: "https://example.test/mcp".to_string(),
            store_dir: root.path().join("store").to_string_lossy().into_owned(),
        }];

        write_oauth_index(&index_path, &entries).unwrap();
        assert_eq!(read_oauth_store_index_from(&index_path).unwrap(), entries);
        let temp_files = std::fs::read_dir(root.path())
            .unwrap()
            .filter_map(Result::ok)
            .filter(|entry| entry.file_name().to_string_lossy().contains(".tmp"))
            .collect::<Vec<_>>();
        assert!(temp_files.is_empty());
    }

    #[test]
    fn browser_open_can_be_disabled_for_headless_runs() {
        unsafe {
            std::env::set_var("MCP_COMPRESSOR_NO_BROWSER", "1");
        }
        assert_eq!(
            open_authorization_url("https://example.test/auth").unwrap(),
            BrowserOpenStatus::Disabled
        );
        unsafe {
            std::env::remove_var("MCP_COMPRESSOR_NO_BROWSER");
        }
    }

    #[test]
    fn callback_request_parser_extracts_and_decodes_code_and_state() {
        let callback = parse_callback_request(
            "GET /callback?code=abc%20123&state=state+value HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n",
        );

        assert_eq!(
            callback,
            OAuthCallbackResult::Success(OAuthCallback {
                code: "abc 123".to_string(),
                state: "state value".to_string(),
            })
        );
    }

    #[test]
    fn callback_request_parser_reports_provider_errors() {
        let callback = parse_callback_request(
            "GET /callback?error=access_denied&error_description=user+cancelled HTTP/1.1\r\n\r\n",
        );

        assert_eq!(
            callback,
            OAuthCallbackResult::ProviderError {
                error: "access_denied".to_string(),
                description: Some("user cancelled".to_string()),
            }
        );
    }

    #[test]
    fn callback_request_parser_prefers_provider_error_over_code() {
        assert_eq!(
            parse_callback_request(
                "GET /callback?code=abc&state=xyz&error=server_error HTTP/1.1\r\n\r\n"
            ),
            OAuthCallbackResult::ProviderError {
                error: "server_error".to_string(),
                description: None,
            }
        );
    }

    #[test]
    fn callback_request_parser_rejects_missing_fields() {
        for request in [
            "GET /callback?code=abc HTTP/1.1\r\n\r\n",
            "GET /callback?state=abc HTTP/1.1\r\n\r\n",
            "GET /callback?code=&state=abc HTTP/1.1\r\n\r\n",
            "GET /callback HTTP/1.1\r\n\r\n",
            "",
        ] {
            assert!(
                matches!(
                    parse_callback_request(request),
                    OAuthCallbackResult::Malformed(_)
                ),
                "request {request:?} should be malformed"
            );
        }
    }

    #[test]
    fn percent_decode_handles_edge_cases() {
        assert_eq!(percent_decode("%41%2f%2F"), "A//");
        assert_eq!(percent_decode("100%"), "100%");
        assert_eq!(percent_decode("%4"), "%4");
        assert_eq!(percent_decode("%zz+x"), "%zz x");
        assert_eq!(percent_decode("%E2%9C%93"), "\u{2713}");
    }

    #[test]
    fn success_callback_response_uses_fixed_message_instead_of_body() {
        let mut response = Vec::new();
        write_callback_response(&mut response, 200, "OAuth complete <script>").unwrap();
        let response = String::from_utf8(response).unwrap();

        assert!(response.starts_with("HTTP/1.1 200 OK"));
        assert!(response.contains("Authorization complete"));
        assert!(response.contains("The requesting app has been successfully authorized"));
        assert!(response.contains("You may close this window"));
        assert!(!response.contains("OAuth complete <script>"));
    }

    #[test]
    fn failure_callback_response_escapes_html_body() {
        let mut response = Vec::new();
        write_callback_response(&mut response, 400, "<script>alert('x')</script>").unwrap();
        let response = String::from_utf8(response).unwrap();

        assert!(response.contains("&lt;script&gt;alert(&#39;x&#39;)&lt;/script&gt;"));
        assert!(!response.contains("<script"));
    }

    #[test]
    fn callback_response_writes_status_headers_and_body() {
        let mut response = Vec::new();
        write_callback_response(&mut response, 400, "nope").unwrap();
        let response = String::from_utf8(response).unwrap();

        assert!(response.starts_with("HTTP/1.1 400 Bad Request"));
        assert!(response.contains("Content-Type: text/html; charset=utf-8"));
        assert!(response.contains("Connection: close"));
        assert!(response.contains("Authorization failed"));
        assert!(response.contains("nope"));

        let (head, body) = response.split_once("\r\n\r\n").unwrap();
        let content_length = head
            .lines()
            .find_map(|line| line.strip_prefix("Content-Length: "))
            .unwrap();
        assert_eq!(content_length.parse::<usize>().unwrap(), body.len());
    }

    #[test]
    fn state_store_sanitizes_file_components() {
        let store = FileStateStore::new("state-dir");

        assert_eq!(
            store.state_path("abc/../def").file_name().unwrap(),
            "abc____def.json"
        );
        assert_eq!(store.state_path("").file_name().unwrap(), "state.json");
    }
}