Skip to main content

mcp_compressor_core/
sdk.rs

1use std::collections::{BTreeMap, HashMap};
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use serde_json::{json, Value};
8
9use crate::client_gen::cli::CliGenerator;
10use crate::client_gen::generator::{ClientGenerator, GeneratorConfig};
11use crate::client_gen::python::PythonGenerator;
12use crate::client_gen::typescript::TypeScriptGenerator;
13use crate::compression::engine::Tool;
14use crate::compression::CompressionLevel;
15use crate::ffi::{normalize_sdk_servers, FfiSdkServerConfig, FfiSdkServersConfig};
16use crate::proxy::{RunningToolProxy, ToolProxyServer};
17use crate::server::{BackendAuthMode, BackendServerConfig};
18use crate::server::{CompressedServer, CompressedServerConfig, ProxyTransformMode};
19use crate::Error;
20
/// How the compressor presents backend tools to its frontend.
///
/// Converted into the server-side [`ProxyTransformMode`] when connecting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressorMode {
    /// Compressed tool wrappers (e.g. a `<server>_invoke_tool` per backend server).
    CompressedTools,
    /// CLI-style presentation (the frontend exposes e.g. `<server>_help` tools).
    Cli,
    /// Just Bash presentation (the frontend exposes e.g. `bash_tool` plus `<server>_help`).
    JustBash,
}
27
28impl From<CompressorMode> for ProxyTransformMode {
29    fn from(value: CompressorMode) -> Self {
30        match value {
31            CompressorMode::CompressedTools => Self::CompressedTools,
32            CompressorMode::Cli => Self::Cli,
33            CompressorMode::JustBash => Self::JustBash,
34        }
35    }
36}
37
38type HeaderProvider = Arc<dyn Fn() -> Result<BTreeMap<String, String>, Error> + Send + Sync>;
39
40#[derive(Clone)]
41pub struct ServerConfig {
42    inner: FfiSdkServerConfig,
43    auth_provider: Option<HeaderProvider>,
44    oauth_app_name: Option<String>,
45}
46
47impl ServerConfig {
48    pub fn command(command: impl Into<String>) -> Self {
49        Self {
50            inner: FfiSdkServerConfig::Structured {
51                command: Some(command.into()),
52                url: None,
53                args: Vec::new(),
54                headers: BTreeMap::new(),
55                oauth_app_name: None,
56            },
57            auth_provider: None,
58            oauth_app_name: None,
59        }
60    }
61
62    pub fn url(url: impl Into<String>) -> Self {
63        Self {
64            inner: FfiSdkServerConfig::Structured {
65                command: None,
66                url: Some(url.into()),
67                args: Vec::new(),
68                headers: BTreeMap::new(),
69                oauth_app_name: None,
70            },
71            auth_provider: None,
72            oauth_app_name: None,
73        }
74    }
75
76    pub fn arg(mut self, arg: impl Into<String>) -> Self {
77        if let FfiSdkServerConfig::Structured { args, .. } = &mut self.inner {
78            args.push(arg.into());
79        }
80        self
81    }
82
83    pub fn args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
84        if let FfiSdkServerConfig::Structured { args: stored, .. } = &mut self.inner {
85            stored.extend(args.into_iter().map(Into::into));
86        }
87        self
88    }
89
90    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
91        if let FfiSdkServerConfig::Structured { headers, .. } = &mut self.inner {
92            headers.insert(name.into(), value.into());
93        }
94        self
95    }
96
97    pub fn auth_provider(
98        mut self,
99        provider: impl Fn() -> Result<BTreeMap<String, String>, Error> + Send + Sync + 'static,
100    ) -> Self {
101        self.auth_provider = Some(Arc::new(provider));
102        self
103    }
104
105    pub fn oauth_app_name(mut self, app_name: impl Into<String>) -> Self {
106        self.oauth_app_name = Some(app_name.into());
107        self
108    }
109
110    fn materialize(mut self) -> (FfiSdkServerConfig, Option<HeaderProvider>) {
111        if let (FfiSdkServerConfig::Structured { oauth_app_name, .. }, Some(app_name)) =
112            (&mut self.inner, self.oauth_app_name.take())
113        {
114            *oauth_app_name = Some(app_name);
115        }
116        (self.inner, self.auth_provider.take())
117    }
118}
119
120#[derive(Clone)]
121pub struct CompressorClientBuilder {
122    servers: BTreeMap<String, ServerConfig>,
123    compression_level: CompressionLevel,
124    server_name: Option<String>,
125    include_tools: Vec<String>,
126    exclude_tools: Vec<String>,
127    toonify: bool,
128    mode: CompressorMode,
129}
130
131impl Default for CompressorClientBuilder {
132    fn default() -> Self {
133        Self {
134            servers: BTreeMap::new(),
135            compression_level: CompressionLevel::Max,
136            server_name: None,
137            include_tools: Vec::new(),
138            exclude_tools: Vec::new(),
139            toonify: false,
140            mode: CompressorMode::CompressedTools,
141        }
142    }
143}
144
145impl CompressorClientBuilder {
146    pub fn server(mut self, name: impl Into<String>, config: ServerConfig) -> Self {
147        self.servers.insert(name.into(), config);
148        self
149    }
150
151    pub fn compression_level(mut self, level: CompressionLevel) -> Self {
152        self.compression_level = level;
153        self
154    }
155
156    pub fn server_name(mut self, server_name: impl Into<String>) -> Self {
157        self.server_name = Some(server_name.into());
158        self
159    }
160
161    pub fn include_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
162        self.include_tools = tools.into_iter().map(Into::into).collect();
163        self
164    }
165
166    pub fn exclude_tools(mut self, tools: impl IntoIterator<Item = impl Into<String>>) -> Self {
167        self.exclude_tools = tools.into_iter().map(Into::into).collect();
168        self
169    }
170
171    pub fn toonify(mut self, enabled: bool) -> Self {
172        self.toonify = enabled;
173        self
174    }
175
176    pub fn mode(mut self, mode: CompressorMode) -> Self {
177        self.mode = mode;
178        self
179    }
180
181    pub fn build(self) -> CompressorClient {
182        CompressorClient { builder: self }
183    }
184}
185
186#[derive(Clone)]
187pub struct CompressorClient {
188    builder: CompressorClientBuilder,
189}
190
191impl CompressorClient {
192    pub fn builder() -> CompressorClientBuilder {
193        CompressorClientBuilder::default()
194    }
195
196    pub async fn connect(&self) -> Result<CompressorProxy, Error> {
197        let materialized = self
198            .builder
199            .servers
200            .clone()
201            .into_iter()
202            .map(|(name, config)| {
203                let (config, provider) = config.materialize();
204                (name, config, provider)
205            })
206            .collect::<Vec<_>>();
207        let providers = materialized
208            .iter()
209            .filter_map(|(name, _, provider)| {
210                provider.clone().map(|provider| (name.clone(), provider))
211            })
212            .collect::<BTreeMap<_, _>>();
213        let ffi_configs = materialized
214            .into_iter()
215            .map(|(name, config, _)| (name, config))
216            .collect::<Vec<_>>();
217        let backends = normalize_sdk_servers(FfiSdkServersConfig::from_iter(ffi_configs))?;
218        let backends = backends
219            .into_iter()
220            .map(|backend| {
221                let name = backend.name.clone();
222                let mut backend = BackendServerConfig::from(backend);
223                if let Some(provider) = providers.get(&name) {
224                    backend = backend
225                        .with_header_provider(Arc::clone(provider))
226                        .with_auth_mode(BackendAuthMode::ExplicitHeaders);
227                }
228                backend
229            })
230            .collect::<Vec<_>>();
231        let server = CompressedServer::connect_multi_stdio(
232            CompressedServerConfig {
233                level: self.builder.compression_level.clone(),
234                server_name: self.builder.server_name.clone(),
235                include_tools: self.builder.include_tools.clone(),
236                exclude_tools: self.builder.exclude_tools.clone(),
237                toonify: self.builder.toonify,
238                transform_mode: self.builder.mode.into(),
239                ..CompressedServerConfig::default()
240            },
241            backends,
242        )
243        .await?;
244        CompressorProxy::start(server).await
245    }
246}
247
248pub struct CompressorProxy {
249    default_server: Option<String>,
250    frontend_tools: Vec<Tool>,
251    backend_tools: Vec<Tool>,
252    backend_tools_by_server: Vec<(String, Tool)>,
253    just_bash_providers: Vec<crate::server::JustBashProviderSpec>,
254    proxy: RunningToolProxy,
255}
256
257impl CompressorProxy {
258    async fn start(server: CompressedServer) -> Result<Self, Error> {
259        let default_server = server.default_server_name().map(str::to_string);
260        let frontend_tools = server.list_frontend_tools().await?;
261        let backend_tools = server.backend_tools();
262        let backend_tools_by_server = server.backend_tools_by_server();
263        let just_bash_providers = server.just_bash_provider_specs();
264        let proxy = ToolProxyServer::start(server).await?;
265        Ok(Self {
266            default_server,
267            frontend_tools,
268            backend_tools,
269            backend_tools_by_server,
270            just_bash_providers,
271            proxy,
272        })
273    }
274
275    pub fn bridge_url(&self) -> &str {
276        self.proxy.bridge_url()
277    }
278
279    pub fn token(&self) -> &str {
280        self.proxy.token_value()
281    }
282
283    pub fn tools(&self) -> &[Tool] {
284        &self.frontend_tools
285    }
286
287    pub fn backend_tools(&self) -> &[Tool] {
288        &self.backend_tools
289    }
290
291    pub fn just_bash_providers(&self) -> &[crate::server::JustBashProviderSpec] {
292        &self.just_bash_providers
293    }
294
295    pub fn schema(&self, tool_name: &str) -> Result<Value, Error> {
296        self.schema_on(self.default_server.as_deref(), tool_name)
297    }
298
299    pub fn schema_on(&self, server: Option<&str>, tool_name: &str) -> Result<Value, Error> {
300        let matches = self
301            .backend_tools_by_server
302            .iter()
303            .filter(|(server_name, tool)| {
304                tool.name == tool_name && server.map(|server| server == server_name).unwrap_or(true)
305            })
306            .collect::<Vec<_>>();
307        match matches.as_slice() {
308            [(_, tool)] => Ok(tool.input_schema.clone()),
309            [] => Err(Error::ToolNotFound(tool_name.to_string())),
310            _ => Err(Error::Config(
311                "Multiple backend tools matched; specify a server".to_string(),
312            )),
313        }
314    }
315
316    pub async fn invoke(&self, tool_name: &str, input: Value) -> Result<String, Error> {
317        self.invoke_on(self.default_server.as_deref(), tool_name, input)
318            .await
319    }
320
321    pub async fn invoke_on(
322        &self,
323        server: Option<&str>,
324        tool_name: &str,
325        input: Value,
326    ) -> Result<String, Error> {
327        let wrapper = self.invoke_wrapper(server)?;
328        self.invoke_wrapper_tool(
329            &wrapper,
330            json!({
331                "tool_name": tool_name,
332                "tool_input": input,
333            }),
334        )
335        .await
336    }
337
338    async fn invoke_wrapper_tool(&self, wrapper: &str, input: Value) -> Result<String, Error> {
339        let client = reqwest::Client::new();
340        let response = client
341            .post(self.proxy.exec_url())
342            .header("Authorization", format!("Bearer {}", self.token()))
343            .json(&json!({
344                "tool": wrapper,
345                "input": input
346            }))
347            .send()
348            .await
349            .map_err(|error| Error::Config(format!("proxy request failed: {error}")))?;
350        let status = response.status();
351        let text = response
352            .text()
353            .await
354            .map_err(|error| Error::Config(format!("proxy response failed: {error}")))?;
355        if status.is_success() {
356            Ok(text)
357        } else {
358            Err(Error::Config(format!(
359                "proxy request failed with {status}: {text}"
360            )))
361        }
362    }
363
364    pub fn executable_tools(&self) -> BTreeMap<String, Box<dyn ExecutableTool + '_>> {
365        self.frontend_tools
366            .iter()
367            .map(|tool| {
368                (
369                    tool.name.clone(),
370                    Box::new(ProxyExecutableTool { proxy: self, tool: tool.clone() })
371                        as Box<dyn ExecutableTool>,
372                )
373            })
374            .collect()
375    }
376
377    pub fn write_cli_client(
378        &self,
379        output_dir: impl AsRef<Path>,
380        name: Option<&str>,
381    ) -> Result<GeneratedClient, Error> {
382        self.write_client(GeneratedClientKind::Cli, output_dir, name)
383    }
384
385    pub fn write_code_client(
386        &self,
387        language: CodeLanguage,
388        output_dir: impl AsRef<Path>,
389        name: Option<&str>,
390    ) -> Result<GeneratedClient, Error> {
391        let kind = match language {
392            CodeLanguage::Python => GeneratedClientKind::Python,
393            CodeLanguage::TypeScript => GeneratedClientKind::TypeScript,
394        };
395        self.write_client(kind, output_dir, name.into())
396    }
397
398    pub fn write_client(
399        &self,
400        kind: GeneratedClientKind,
401        output_dir: impl AsRef<Path>,
402        name: Option<&str>,
403    ) -> Result<GeneratedClient, Error> {
404        let generator_config = GeneratorConfig {
405            cli_name: name
406                .or(self.default_server.as_deref())
407                .unwrap_or("mcp")
408                .to_string(),
409            bridge_url: self.bridge_url().to_string(),
410            token: self.token().to_string(),
411            tools: self.backend_tools.clone(),
412            session_pid: 0,
413            output_dir: output_dir.as_ref().to_path_buf(),
414        };
415        let files = match kind {
416            GeneratedClientKind::Cli => CliGenerator.generate(&generator_config),
417            GeneratedClientKind::Python => PythonGenerator.generate(&generator_config),
418            GeneratedClientKind::TypeScript => TypeScriptGenerator.generate(&generator_config),
419        }?;
420        let environment = kind.environment(&generator_config);
421        Ok(GeneratedClient {
422            kind,
423            output_dir: generator_config.output_dir,
424            files,
425            environment,
426        })
427    }
428
429    fn invoke_wrapper(&self, server: Option<&str>) -> Result<String, Error> {
430        let suffix = "_invoke_tool";
431        let matches = self
432            .frontend_tools
433            .iter()
434            .filter(|tool| tool.name.ends_with(suffix))
435            .filter(|tool| {
436                server
437                    .map(|name| tool.name == format!("{name}{suffix}"))
438                    .unwrap_or(true)
439            })
440            .map(|tool| tool.name.clone())
441            .collect::<Vec<_>>();
442        match matches.as_slice() {
443            [name] => Ok(name.clone()),
444            [] => Err(Error::Config(format!(
445                "No compressed invoke wrapper found for server {}",
446                server.unwrap_or("<default>")
447            ))),
448            _ => Err(Error::Config(
449                "Multiple compressed invoke wrappers found; specify a server".to_string(),
450            )),
451        }
452    }
453}
454
455#[async_trait]
456pub trait ExecutableTool: Send + Sync {
457    fn name(&self) -> &str;
458    fn description(&self) -> Option<&str>;
459    fn input_schema(&self) -> &Value;
460    async fn execute(&self, input: Value) -> Result<String, Error>;
461}
462
463struct ProxyExecutableTool<'a> {
464    proxy: &'a CompressorProxy,
465    tool: Tool,
466}
467
468#[async_trait]
469impl ExecutableTool for ProxyExecutableTool<'_> {
470    fn name(&self) -> &str {
471        &self.tool.name
472    }
473
474    fn description(&self) -> Option<&str> {
475        self.tool.description.as_deref()
476    }
477
478    fn input_schema(&self) -> &Value {
479        &self.tool.input_schema
480    }
481
482    async fn execute(&self, input: Value) -> Result<String, Error> {
483        self.proxy.invoke_wrapper_tool(&self.tool.name, input).await
484    }
485}
486
/// Target language for [`CompressorProxy::write_code_client`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodeLanguage {
    /// Generate a Python client.
    Python,
    /// Generate a TypeScript client.
    TypeScript,
}
492
493#[derive(Debug, Clone, PartialEq, Eq)]
494pub struct GeneratedClient {
495    pub kind: GeneratedClientKind,
496    pub output_dir: PathBuf,
497    pub files: Vec<PathBuf>,
498    pub environment: HashMap<String, String>,
499}
500
/// Kind of client produced by [`CompressorProxy::write_client`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GeneratedClientKind {
    /// Command-line client.
    Cli,
    /// Python client.
    Python,
    /// TypeScript client.
    TypeScript,
}
507
508impl GeneratedClientKind {
509    fn environment(self, config: &GeneratorConfig) -> HashMap<String, String> {
510        match self {
511            GeneratedClientKind::Python => HashMap::from([(
512                "PYTHONPATH".to_string(),
513                config.output_dir.to_string_lossy().to_string(),
514            )]),
515            GeneratedClientKind::Cli | GeneratedClientKind::TypeScript => HashMap::new(),
516        }
517    }
518}
519
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Absolute path of a fixture under this crate's `tests/fixtures` directory.
    fn fixture_path(name: &str) -> String {
        let manifest_dir = env!("CARGO_MANIFEST_DIR");
        format!("{manifest_dir}/tests/fixtures/{name}")
    }

    /// Interpreter used to run the Python fixture servers (`$PYTHON`, else `python3`).
    fn python_command() -> String {
        match std::env::var("PYTHON") {
            Ok(interpreter) => interpreter,
            Err(_) => "python3".to_string(),
        }
    }

    /// Command-based server config that runs the given Python fixture script.
    fn fixture_server(script: &str) -> ServerConfig {
        ServerConfig::command(python_command()).arg(fixture_path(script))
    }

    /// Connects a client with only the `alpha` fixture server in the given mode.
    async fn connect_alpha(mode: CompressorMode) -> CompressorProxy {
        CompressorClient::builder()
            .server("alpha", fixture_server("alpha_server.py"))
            .mode(mode)
            .build()
            .connect()
            .await
            .unwrap()
    }

    #[test]
    fn server_config_oauth_app_name_is_preserved_for_transport_layer() {
        let (config, _) = ServerConfig::url("https://example.test/mcp")
            .oauth_app_name("Rovo Dev")
            .materialize();

        if let FfiSdkServerConfig::Structured { oauth_app_name, .. } = config {
            assert_eq!(oauth_app_name.as_deref(), Some("Rovo Dev"));
        } else {
            panic!("expected structured config");
        }
    }

    #[test]
    fn server_config_auth_provider_is_preserved_for_transport_layer() {
        let (config, provider) = ServerConfig::url("https://example.test/mcp")
            .header("X-Static", "yes")
            .auth_provider(|| {
                Ok(BTreeMap::from([(
                    "Authorization".to_string(),
                    "Bearer dynamic".to_string(),
                )]))
            })
            .materialize();
        assert!(provider.is_some());

        let backends = normalize_sdk_servers(FfiSdkServersConfig::from_iter([(
            "remote".to_string(),
            config,
        )]))
        .unwrap();
        let remote = &backends[0];

        assert_eq!(remote.command_or_url, "https://example.test/mcp");
        assert_eq!(
            remote.args,
            ["-H", "X-Static=yes", "--auth", "explicit-headers"]
        );
    }

    #[tokio::test]
    async fn compressor_client_invokes_single_server_without_compressor_subprocess() {
        let proxy = CompressorClient::builder()
            .server("alpha", fixture_server("alpha_server.py"))
            .compression_level(CompressionLevel::Max)
            .build()
            .connect()
            .await
            .unwrap();

        let has_alpha_wrapper = proxy
            .tools()
            .iter()
            .any(|tool| tool.name == "alpha_invoke_tool");
        assert!(has_alpha_wrapper);

        let echoed = proxy
            .invoke("echo", json!({ "message": "rust-sdk" }))
            .await
            .unwrap();
        assert_eq!(echoed, "alpha:rust-sdk");

        let tools = proxy.executable_tools();
        let wrapper = tools
            .get("alpha_invoke_tool")
            .expect("alpha wrapper should be executable");
        let via_wrapper = wrapper
            .execute(json!({
                "tool_name": "echo",
                "tool_input": { "message": "executable-rust" }
            }))
            .await
            .unwrap();
        assert_eq!(via_wrapper, "alpha:executable-rust");
    }

    #[tokio::test]
    async fn compressor_client_routes_multiple_servers() {
        let proxy = CompressorClient::builder()
            .server("alpha", fixture_server("alpha_server.py"))
            .server("beta", fixture_server("beta_server.py"))
            .compression_level(CompressionLevel::Max)
            .build()
            .connect()
            .await
            .unwrap();

        let sum = proxy
            .invoke_on(Some("alpha"), "add", json!({ "a": 2, "b": 3 }))
            .await
            .unwrap();
        let product = proxy
            .invoke_on(Some("beta"), "multiply", json!({ "a": 4, "b": 5 }))
            .await
            .unwrap();
        assert_eq!(sum, "5");
        assert_eq!(product, "20");
    }

    #[tokio::test]
    async fn compressor_client_writes_generated_clients() {
        let proxy = CompressorClient::builder()
            .server("alpha", fixture_server("alpha_server.py"))
            .compression_level(CompressionLevel::Max)
            .build()
            .connect()
            .await
            .unwrap();
        let tempdir = tempfile::tempdir().unwrap();
        let out_dir = tempdir.path();

        let python = proxy
            .write_code_client(CodeLanguage::Python, out_dir, Some("alpha"))
            .unwrap();
        assert_eq!(python.kind, GeneratedClientKind::Python);
        assert!(python.files.iter().any(|file| file.ends_with("alpha.py")));
        let expected_pythonpath = out_dir.to_string_lossy().to_string();
        assert_eq!(
            python.environment.get("PYTHONPATH"),
            Some(&expected_pythonpath)
        );

        let cli = proxy.write_cli_client(out_dir, Some("alpha")).unwrap();
        assert_eq!(cli.kind, GeneratedClientKind::Cli);
    }

    #[tokio::test]
    async fn compressor_client_exposes_cli_and_just_bash_modes() {
        let cli = connect_alpha(CompressorMode::Cli).await;
        assert!(cli.tools().iter().any(|tool| tool.name == "alpha_help"));

        let bash = connect_alpha(CompressorMode::JustBash).await;
        let bash_tools = bash.tools();
        assert!(bash_tools.iter().any(|tool| tool.name == "bash_tool"));
        assert!(bash_tools.iter().any(|tool| tool.name == "alpha_help"));

        let alpha_provider = bash
            .just_bash_providers()
            .iter()
            .find(|provider| provider.provider_name == "alpha")
            .expect("alpha should be registered as a Just Bash provider");
        assert_eq!(alpha_provider.help_tool_name, "alpha_help");

        let routes_echo = alpha_provider.tools.iter().any(|command| {
            command.command_name == "echo"
                && command.backend_tool_name == "echo"
                && command.invoke_tool_name == "alpha_invoke_tool"
        });
        assert!(routes_echo);
    }
}