Skip to main content

adk_rust_mcp_image/
server.rs

//! MCP server implementation for the image server.
//!
//! This module provides the MCP server handler that exposes:
//! - `image_generate` tool for text-to-image generation
//! - `image_upscale` tool for image upscaling
//! - Resources for models, segmentation classes, and providers
7
8use crate::handler::{ImageGenerateParams, ImageGenerateResult, ImageHandler, ImageUpscaleParams, ImageUpscaleResult};
9use crate::resources;
10use adk_rust_mcp_common::config::Config;
11use adk_rust_mcp_common::error::Error;
12use rmcp::{
13    model::{
14        CallToolResult, Content, ListResourcesResult, ReadResourceResult,
15        ResourceContents, ServerCapabilities, ServerInfo,
16    },
17    ErrorData as McpError, ServerHandler,
18};
19use schemars::JsonSchema;
20use serde::Deserialize;
21use std::borrow::Cow;
22use std::sync::Arc;
23use tokio::sync::RwLock;
24use tracing::{debug, info};
25
26/// MCP Server for image generation.
27#[derive(Clone)]
28pub struct ImageServer {
29    /// Handler for image generation operations
30    handler: Arc<RwLock<Option<ImageHandler>>>,
31    /// Server configuration
32    config: Config,
33}
34
35/// Tool parameters wrapper for image_generate.
36#[derive(Debug, Deserialize, JsonSchema)]
37pub struct ImageGenerateToolParams {
38    /// Text prompt describing the image to generate
39    pub prompt: String,
40    /// Negative prompt - what to avoid in the generated image
41    #[serde(default)]
42    pub negative_prompt: Option<String>,
43    /// Model to use for generation (default: imagen-4.0-generate-preview-05-20)
44    #[serde(default)]
45    pub model: Option<String>,
46    /// Aspect ratio (1:1, 3:4, 4:3, 9:16, 16:9)
47    #[serde(default)]
48    pub aspect_ratio: Option<String>,
49    /// Number of images to generate (1-4)
50    #[serde(default)]
51    pub number_of_images: Option<u8>,
52    /// Random seed for reproducibility
53    #[serde(default)]
54    pub seed: Option<i64>,
55    /// Output file path for saving locally
56    #[serde(default)]
57    pub output_file: Option<String>,
58    /// Output storage URI (e.g., gs://bucket/path)
59    #[serde(default)]
60    pub output_uri: Option<String>,
61}
62
63impl From<ImageGenerateToolParams> for ImageGenerateParams {
64    fn from(params: ImageGenerateToolParams) -> Self {
65        Self {
66            prompt: params.prompt,
67            negative_prompt: params.negative_prompt,
68            model: params.model.unwrap_or_else(|| crate::handler::DEFAULT_MODEL.to_string()),
69            aspect_ratio: params.aspect_ratio.unwrap_or_else(|| "1:1".to_string()),
70            number_of_images: params.number_of_images.unwrap_or(1),
71            seed: params.seed,
72            output_file: params.output_file,
73            output_uri: params.output_uri,
74        }
75    }
76}
77
78/// Tool parameters wrapper for image_upscale.
79#[derive(Debug, Deserialize, JsonSchema)]
80pub struct ImageUpscaleToolParams {
81    /// Source image to upscale (base64 data, local path, or GCS URI)
82    pub image: String,
83    /// Upscale factor: "x2" or "x4" (default: "x2")
84    #[serde(default)]
85    pub upscale_factor: Option<String>,
86    /// Output file path for saving locally
87    #[serde(default)]
88    pub output_file: Option<String>,
89    /// Output storage URI (e.g., gs://bucket/path)
90    #[serde(default)]
91    pub output_uri: Option<String>,
92}
93
94impl From<ImageUpscaleToolParams> for ImageUpscaleParams {
95    fn from(params: ImageUpscaleToolParams) -> Self {
96        Self {
97            image: params.image,
98            upscale_factor: params.upscale_factor.unwrap_or_else(|| "x2".to_string()),
99            output_file: params.output_file,
100            output_uri: params.output_uri,
101        }
102    }
103}
104
105impl ImageServer {
106    /// Create a new ImageServer with the given configuration.
107    pub fn new(config: Config) -> Self {
108        Self {
109            handler: Arc::new(RwLock::new(None)),
110            config,
111        }
112    }
113
114    /// Initialize the handler (called lazily on first use).
115    async fn ensure_handler(&self) -> Result<(), Error> {
116        let mut handler = self.handler.write().await;
117        if handler.is_none() {
118            *handler = Some(ImageHandler::new(self.config.clone()).await?);
119        }
120        Ok(())
121    }
122
123    /// Generate images from a text prompt.
124    pub async fn generate_image(&self, params: ImageGenerateToolParams) -> Result<CallToolResult, McpError> {
125        info!(prompt = %params.prompt, "Generating image");
126
127        // Ensure handler is initialized
128        self.ensure_handler().await.map_err(|e| {
129            McpError::internal_error(format!("Failed to initialize handler: {}", e), None)
130        })?;
131
132        let handler_guard = self.handler.read().await;
133        let handler = handler_guard.as_ref().ok_or_else(|| {
134            McpError::internal_error("Handler not initialized", None)
135        })?;
136
137        let gen_params: ImageGenerateParams = params.into();
138        let result = handler.generate_image(gen_params).await.map_err(|e| {
139            McpError::internal_error(format!("Image generation failed: {}", e), None)
140        })?;
141
142        // Convert result to MCP content
143        let content = match result {
144            ImageGenerateResult::Base64(images) => {
145                images
146                    .into_iter()
147                    .map(|img| Content::image(img.data, img.mime_type))
148                    .collect()
149            }
150            ImageGenerateResult::LocalFiles(paths) => {
151                vec![Content::text(format!("Images saved to: {}", paths.join(", ")))]
152            }
153            ImageGenerateResult::StorageUris(uris) => {
154                vec![Content::text(format!("Images uploaded to: {}", uris.join(", ")))]
155            }
156        };
157
158        Ok(CallToolResult::success(content))
159    }
160
161    /// Upscale an image.
162    pub async fn upscale_image(&self, params: ImageUpscaleToolParams) -> Result<CallToolResult, McpError> {
163        info!(upscale_factor = ?params.upscale_factor, "Upscaling image");
164
165        // Ensure handler is initialized
166        self.ensure_handler().await.map_err(|e| {
167            McpError::internal_error(format!("Failed to initialize handler: {}", e), None)
168        })?;
169
170        let handler_guard = self.handler.read().await;
171        let handler = handler_guard.as_ref().ok_or_else(|| {
172            McpError::internal_error("Handler not initialized", None)
173        })?;
174
175        let upscale_params: ImageUpscaleParams = params.into();
176        let result = handler.upscale_image(upscale_params).await.map_err(|e| {
177            McpError::internal_error(format!("Image upscaling failed: {}", e), None)
178        })?;
179
180        // Convert result to MCP content
181        let content = match result {
182            ImageUpscaleResult::Base64(image) => {
183                vec![Content::image(image.data, image.mime_type)]
184            }
185            ImageUpscaleResult::LocalFile(path) => {
186                vec![Content::text(format!("Upscaled image saved to: {}", path))]
187            }
188            ImageUpscaleResult::StorageUri(uri) => {
189                vec![Content::text(format!("Upscaled image uploaded to: {}", uri))]
190            }
191        };
192
193        Ok(CallToolResult::success(content))
194    }
195}
196
197impl ServerHandler for ImageServer {
198    fn get_info(&self) -> ServerInfo {
199        ServerInfo {
200            instructions: Some(
201                "Image generation and processing server using Google Vertex AI Imagen API. \
202                 Use image_generate to create images from text prompts, \
203                 and image_upscale to upscale existing images."
204                    .to_string(),
205            ),
206            capabilities: ServerCapabilities::builder()
207                .enable_tools()
208                .enable_resources()
209                .build(),
210            ..Default::default()
211        }
212    }
213
214    fn list_tools(
215        &self,
216        _params: Option<rmcp::model::PaginatedRequestParams>,
217        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
218    ) -> impl std::future::Future<Output = Result<rmcp::model::ListToolsResult, McpError>> + Send + '_ {
219        async move {
220            use rmcp::model::{ListToolsResult, Tool};
221            use schemars::schema_for;
222
223            // image_generate tool
224            let gen_schema = schema_for!(ImageGenerateToolParams);
225            let gen_schema_value = serde_json::to_value(&gen_schema).unwrap_or_default();
226            let gen_input_schema = match gen_schema_value {
227                serde_json::Value::Object(map) => Arc::new(map),
228                _ => Arc::new(serde_json::Map::new()),
229            };
230
231            // image_upscale tool
232            let upscale_schema = schema_for!(ImageUpscaleToolParams);
233            let upscale_schema_value = serde_json::to_value(&upscale_schema).unwrap_or_default();
234            let upscale_input_schema = match upscale_schema_value {
235                serde_json::Value::Object(map) => Arc::new(map),
236                _ => Arc::new(serde_json::Map::new()),
237            };
238
239            Ok(ListToolsResult {
240                tools: vec![
241                    Tool {
242                        name: Cow::Borrowed("image_generate"),
243                        description: Some(Cow::Borrowed(
244                            "Generate images from a text prompt using Google's Imagen API. \
245                             Returns base64-encoded image data, local file paths, or storage URIs \
246                             depending on output parameters."
247                        )),
248                        input_schema: gen_input_schema,
249                        annotations: None,
250                        icons: None,
251                        meta: None,
252                        output_schema: None,
253                        title: None,
254                    },
255                    Tool {
256                        name: Cow::Borrowed("image_upscale"),
257                        description: Some(Cow::Borrowed(
258                            "Upscale an image using Google's Imagen 4.0 Upscale API. \
259                             Supports x2 and x4 upscale factors. \
260                             Accepts base64 image data, local file path, or GCS URI as input. \
261                             Returns base64-encoded image data, local file path, or storage URI."
262                        )),
263                        input_schema: upscale_input_schema,
264                        annotations: None,
265                        icons: None,
266                        meta: None,
267                        output_schema: None,
268                        title: None,
269                    },
270                ],
271                next_cursor: None,
272                meta: None,
273            })
274        }
275    }
276
277    fn call_tool(
278        &self,
279        params: rmcp::model::CallToolRequestParams,
280        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
281    ) -> impl std::future::Future<Output = Result<CallToolResult, McpError>> + Send + '_ {
282        async move {
283            match params.name.as_ref() {
284                "image_generate" => {
285                    let tool_params: ImageGenerateToolParams = params
286                        .arguments
287                        .map(|args| serde_json::from_value(serde_json::Value::Object(args)))
288                        .transpose()
289                        .map_err(|e| McpError::invalid_params(format!("Invalid parameters: {}", e), None))?
290                        .ok_or_else(|| McpError::invalid_params("Missing parameters", None))?;
291
292                    self.generate_image(tool_params).await
293                }
294                "image_upscale" => {
295                    let tool_params: ImageUpscaleToolParams = params
296                        .arguments
297                        .map(|args| serde_json::from_value(serde_json::Value::Object(args)))
298                        .transpose()
299                        .map_err(|e| McpError::invalid_params(format!("Invalid parameters: {}", e), None))?
300                        .ok_or_else(|| McpError::invalid_params("Missing parameters", None))?;
301
302                    self.upscale_image(tool_params).await
303                }
304                _ => Err(McpError::invalid_params(format!("Unknown tool: {}", params.name), None)),
305            }
306        }
307    }
308
309    fn list_resources(
310        &self,
311        _params: Option<rmcp::model::PaginatedRequestParams>,
312        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
313    ) -> impl std::future::Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
314        async move {
315            debug!("Listing resources");
316            
317            // Build resources using the raw struct approach
318            let models_resource = rmcp::model::Resource {
319                raw: rmcp::model::RawResource {
320                    uri: "image://models".to_string(),
321                    name: "Available Image Models".to_string(),
322                    title: None,
323                    description: Some("List of available image generation models".to_string()),
324                    mime_type: Some("application/json".to_string()),
325                    size: None,
326                    icons: None,
327                    meta: None,
328                },
329                annotations: None,
330            };
331
332            let segmentation_resource = rmcp::model::Resource {
333                raw: rmcp::model::RawResource {
334                    uri: "image://segmentation_classes".to_string(),
335                    name: "Segmentation Classes".to_string(),
336                    title: None,
337                    description: Some("List of segmentation classes for image editing (Google provider)".to_string()),
338                    mime_type: Some("application/json".to_string()),
339                    size: None,
340                    icons: None,
341                    meta: None,
342                },
343                annotations: None,
344            };
345
346            let providers_resource = rmcp::model::Resource {
347                raw: rmcp::model::RawResource {
348                    uri: "image://providers".to_string(),
349                    name: "Available Providers".to_string(),
350                    title: None,
351                    description: Some("List of available image generation providers".to_string()),
352                    mime_type: Some("application/json".to_string()),
353                    size: None,
354                    icons: None,
355                    meta: None,
356                },
357                annotations: None,
358            };
359
360            Ok(ListResourcesResult {
361                resources: vec![models_resource, segmentation_resource, providers_resource],
362                next_cursor: None,
363                meta: None,
364            })
365        }
366    }
367
368    fn read_resource(
369        &self,
370        params: rmcp::model::ReadResourceRequestParams,
371        _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
372    ) -> impl std::future::Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ {
373        async move {
374            let uri = &params.uri;
375            debug!(uri = %uri, "Reading resource");
376
377            let content = match uri.as_str() {
378                "image://models" => resources::models_resource_json(),
379                "image://segmentation_classes" => resources::segmentation_classes_resource_json(),
380                "image://providers" => resources::providers_resource_json(),
381                _ => {
382                    return Err(McpError::resource_not_found(
383                        format!("Unknown resource: {}", uri),
384                        None,
385                    ));
386                }
387            };
388
389            Ok(ReadResourceResult {
390                contents: vec![ResourceContents::text(content, uri.clone())],
391            })
392        }
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration for constructing an `ImageServer` in unit tests.
    fn test_config() -> Config {
        Config {
            project_id: "test-project".to_string(),
            location: "us-central1".to_string(),
            gcs_bucket: None,
            port: 8080,
            ..Default::default()
        }
    }

    #[test]
    fn test_server_info() {
        let server = ImageServer::new(test_config());
        let info = server.get_info();
        assert!(info.instructions.is_some());
        assert!(info.capabilities.tools.is_some());
        assert!(info.capabilities.resources.is_some());
    }

    #[test]
    fn test_tool_params_conversion() {
        let tool_params = ImageGenerateToolParams {
            prompt: "A cat".to_string(),
            negative_prompt: Some("blurry".to_string()),
            model: Some("imagen-4".to_string()),
            aspect_ratio: Some("16:9".to_string()),
            number_of_images: Some(2),
            seed: Some(42),
            output_file: None,
            output_uri: None,
        };

        let gen_params: ImageGenerateParams = tool_params.into();
        assert_eq!(gen_params.prompt, "A cat");
        assert_eq!(gen_params.negative_prompt, Some("blurry".to_string()));
        assert_eq!(gen_params.model, "imagen-4");
        assert_eq!(gen_params.aspect_ratio, "16:9");
        assert_eq!(gen_params.number_of_images, 2);
        assert_eq!(gen_params.seed, Some(42));
    }

    #[test]
    fn test_tool_params_defaults() {
        let tool_params = ImageGenerateToolParams {
            prompt: "A cat".to_string(),
            negative_prompt: None,
            model: None,
            aspect_ratio: None,
            number_of_images: None,
            seed: None,
            output_file: None,
            output_uri: None,
        };

        let gen_params: ImageGenerateParams = tool_params.into();
        assert_eq!(gen_params.model, crate::handler::DEFAULT_MODEL);
        assert_eq!(gen_params.aspect_ratio, "1:1");
        assert_eq!(gen_params.number_of_images, 1);
    }

    #[test]
    fn test_upscale_params_conversion() {
        let tool_params = ImageUpscaleToolParams {
            image: "gs://bucket/input.png".to_string(),
            upscale_factor: Some("x4".to_string()),
            output_file: Some("/tmp/out.png".to_string()),
            output_uri: None,
        };

        let upscale_params: ImageUpscaleParams = tool_params.into();
        assert_eq!(upscale_params.image, "gs://bucket/input.png");
        assert_eq!(upscale_params.upscale_factor, "x4");
        assert_eq!(upscale_params.output_file, Some("/tmp/out.png".to_string()));
        assert_eq!(upscale_params.output_uri, None);
    }

    #[test]
    fn test_upscale_params_defaults() {
        let tool_params = ImageUpscaleToolParams {
            image: "gs://bucket/input.png".to_string(),
            upscale_factor: None,
            output_file: None,
            output_uri: None,
        };

        let upscale_params: ImageUpscaleParams = tool_params.into();
        assert_eq!(upscale_params.upscale_factor, "x2");
    }

    #[test]
    fn test_generate_params_deserialize_minimal() {
        let params: ImageGenerateToolParams =
            serde_json::from_value(serde_json::json!({ "prompt": "A cat" }))
                .expect("a prompt alone is a valid argument object");
        assert_eq!(params.prompt, "A cat");
        assert!(params.model.is_none());
        assert!(params.aspect_ratio.is_none());
        assert!(params.number_of_images.is_none());
    }

    #[test]
    fn test_generate_params_require_prompt() {
        let result = serde_json::from_value::<ImageGenerateToolParams>(serde_json::json!({}));
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_schemas_describe_required_fields() {
        let generate = serde_json::to_value(schemars::schema_for!(ImageGenerateToolParams))
            .expect("schema serializes");
        assert!(generate["properties"]["prompt"].is_object());

        let upscale = serde_json::to_value(schemars::schema_for!(ImageUpscaleToolParams))
            .expect("schema serializes");
        assert!(upscale["properties"]["image"].is_object());
    }
}