1use crate::handler::{ImageGenerateParams, ImageGenerateResult, ImageHandler, ImageUpscaleParams, ImageUpscaleResult};
9use crate::resources;
10use adk_rust_mcp_common::config::Config;
11use adk_rust_mcp_common::error::Error;
12use rmcp::{
13 model::{
14 CallToolResult, Content, ListResourcesResult, ReadResourceResult,
15 ResourceContents, ServerCapabilities, ServerInfo,
16 },
17 ErrorData as McpError, ServerHandler,
18};
19use schemars::JsonSchema;
20use serde::Deserialize;
21use std::borrow::Cow;
22use std::sync::Arc;
23use tokio::sync::RwLock;
24use tracing::{debug, info};
25
26#[derive(Clone)]
28pub struct ImageServer {
29 handler: Arc<RwLock<Option<ImageHandler>>>,
31 config: Config,
33}
34
35#[derive(Debug, Deserialize, JsonSchema)]
37pub struct ImageGenerateToolParams {
38 pub prompt: String,
40 #[serde(default)]
42 pub negative_prompt: Option<String>,
43 #[serde(default)]
45 pub model: Option<String>,
46 #[serde(default)]
48 pub aspect_ratio: Option<String>,
49 #[serde(default)]
51 pub number_of_images: Option<u8>,
52 #[serde(default)]
54 pub seed: Option<i64>,
55 #[serde(default)]
57 pub output_file: Option<String>,
58 #[serde(default)]
60 pub output_uri: Option<String>,
61}
62
63impl From<ImageGenerateToolParams> for ImageGenerateParams {
64 fn from(params: ImageGenerateToolParams) -> Self {
65 Self {
66 prompt: params.prompt,
67 negative_prompt: params.negative_prompt,
68 model: params.model.unwrap_or_else(|| crate::handler::DEFAULT_MODEL.to_string()),
69 aspect_ratio: params.aspect_ratio.unwrap_or_else(|| "1:1".to_string()),
70 number_of_images: params.number_of_images.unwrap_or(1),
71 seed: params.seed,
72 output_file: params.output_file,
73 output_uri: params.output_uri,
74 }
75 }
76}
77
78#[derive(Debug, Deserialize, JsonSchema)]
80pub struct ImageUpscaleToolParams {
81 pub image: String,
83 #[serde(default)]
85 pub upscale_factor: Option<String>,
86 #[serde(default)]
88 pub output_file: Option<String>,
89 #[serde(default)]
91 pub output_uri: Option<String>,
92}
93
94impl From<ImageUpscaleToolParams> for ImageUpscaleParams {
95 fn from(params: ImageUpscaleToolParams) -> Self {
96 Self {
97 image: params.image,
98 upscale_factor: params.upscale_factor.unwrap_or_else(|| "x2".to_string()),
99 output_file: params.output_file,
100 output_uri: params.output_uri,
101 }
102 }
103}
104
105impl ImageServer {
106 pub fn new(config: Config) -> Self {
108 Self {
109 handler: Arc::new(RwLock::new(None)),
110 config,
111 }
112 }
113
114 async fn ensure_handler(&self) -> Result<(), Error> {
116 let mut handler = self.handler.write().await;
117 if handler.is_none() {
118 *handler = Some(ImageHandler::new(self.config.clone()).await?);
119 }
120 Ok(())
121 }
122
123 pub async fn generate_image(&self, params: ImageGenerateToolParams) -> Result<CallToolResult, McpError> {
125 info!(prompt = %params.prompt, "Generating image");
126
127 self.ensure_handler().await.map_err(|e| {
129 McpError::internal_error(format!("Failed to initialize handler: {}", e), None)
130 })?;
131
132 let handler_guard = self.handler.read().await;
133 let handler = handler_guard.as_ref().ok_or_else(|| {
134 McpError::internal_error("Handler not initialized", None)
135 })?;
136
137 let gen_params: ImageGenerateParams = params.into();
138 let result = handler.generate_image(gen_params).await.map_err(|e| {
139 McpError::internal_error(format!("Image generation failed: {}", e), None)
140 })?;
141
142 let content = match result {
144 ImageGenerateResult::Base64(images) => {
145 images
146 .into_iter()
147 .map(|img| Content::image(img.data, img.mime_type))
148 .collect()
149 }
150 ImageGenerateResult::LocalFiles(paths) => {
151 vec![Content::text(format!("Images saved to: {}", paths.join(", ")))]
152 }
153 ImageGenerateResult::StorageUris(uris) => {
154 vec![Content::text(format!("Images uploaded to: {}", uris.join(", ")))]
155 }
156 };
157
158 Ok(CallToolResult::success(content))
159 }
160
161 pub async fn upscale_image(&self, params: ImageUpscaleToolParams) -> Result<CallToolResult, McpError> {
163 info!(upscale_factor = ?params.upscale_factor, "Upscaling image");
164
165 self.ensure_handler().await.map_err(|e| {
167 McpError::internal_error(format!("Failed to initialize handler: {}", e), None)
168 })?;
169
170 let handler_guard = self.handler.read().await;
171 let handler = handler_guard.as_ref().ok_or_else(|| {
172 McpError::internal_error("Handler not initialized", None)
173 })?;
174
175 let upscale_params: ImageUpscaleParams = params.into();
176 let result = handler.upscale_image(upscale_params).await.map_err(|e| {
177 McpError::internal_error(format!("Image upscaling failed: {}", e), None)
178 })?;
179
180 let content = match result {
182 ImageUpscaleResult::Base64(image) => {
183 vec![Content::image(image.data, image.mime_type)]
184 }
185 ImageUpscaleResult::LocalFile(path) => {
186 vec![Content::text(format!("Upscaled image saved to: {}", path))]
187 }
188 ImageUpscaleResult::StorageUri(uri) => {
189 vec![Content::text(format!("Upscaled image uploaded to: {}", uri))]
190 }
191 };
192
193 Ok(CallToolResult::success(content))
194 }
195}
196
197impl ServerHandler for ImageServer {
198 fn get_info(&self) -> ServerInfo {
199 ServerInfo {
200 instructions: Some(
201 "Image generation and processing server using Google Vertex AI Imagen API. \
202 Use image_generate to create images from text prompts, \
203 and image_upscale to upscale existing images."
204 .to_string(),
205 ),
206 capabilities: ServerCapabilities::builder()
207 .enable_tools()
208 .enable_resources()
209 .build(),
210 ..Default::default()
211 }
212 }
213
214 fn list_tools(
215 &self,
216 _params: Option<rmcp::model::PaginatedRequestParams>,
217 _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
218 ) -> impl std::future::Future<Output = Result<rmcp::model::ListToolsResult, McpError>> + Send + '_ {
219 async move {
220 use rmcp::model::{ListToolsResult, Tool};
221 use schemars::schema_for;
222
223 let gen_schema = schema_for!(ImageGenerateToolParams);
225 let gen_schema_value = serde_json::to_value(&gen_schema).unwrap_or_default();
226 let gen_input_schema = match gen_schema_value {
227 serde_json::Value::Object(map) => Arc::new(map),
228 _ => Arc::new(serde_json::Map::new()),
229 };
230
231 let upscale_schema = schema_for!(ImageUpscaleToolParams);
233 let upscale_schema_value = serde_json::to_value(&upscale_schema).unwrap_or_default();
234 let upscale_input_schema = match upscale_schema_value {
235 serde_json::Value::Object(map) => Arc::new(map),
236 _ => Arc::new(serde_json::Map::new()),
237 };
238
239 Ok(ListToolsResult {
240 tools: vec![
241 Tool {
242 name: Cow::Borrowed("image_generate"),
243 description: Some(Cow::Borrowed(
244 "Generate images from a text prompt using Google's Imagen API. \
245 Returns base64-encoded image data, local file paths, or storage URIs \
246 depending on output parameters."
247 )),
248 input_schema: gen_input_schema,
249 annotations: None,
250 icons: None,
251 meta: None,
252 output_schema: None,
253 title: None,
254 },
255 Tool {
256 name: Cow::Borrowed("image_upscale"),
257 description: Some(Cow::Borrowed(
258 "Upscale an image using Google's Imagen 4.0 Upscale API. \
259 Supports x2 and x4 upscale factors. \
260 Accepts base64 image data, local file path, or GCS URI as input. \
261 Returns base64-encoded image data, local file path, or storage URI."
262 )),
263 input_schema: upscale_input_schema,
264 annotations: None,
265 icons: None,
266 meta: None,
267 output_schema: None,
268 title: None,
269 },
270 ],
271 next_cursor: None,
272 meta: None,
273 })
274 }
275 }
276
277 fn call_tool(
278 &self,
279 params: rmcp::model::CallToolRequestParams,
280 _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
281 ) -> impl std::future::Future<Output = Result<CallToolResult, McpError>> + Send + '_ {
282 async move {
283 match params.name.as_ref() {
284 "image_generate" => {
285 let tool_params: ImageGenerateToolParams = params
286 .arguments
287 .map(|args| serde_json::from_value(serde_json::Value::Object(args)))
288 .transpose()
289 .map_err(|e| McpError::invalid_params(format!("Invalid parameters: {}", e), None))?
290 .ok_or_else(|| McpError::invalid_params("Missing parameters", None))?;
291
292 self.generate_image(tool_params).await
293 }
294 "image_upscale" => {
295 let tool_params: ImageUpscaleToolParams = params
296 .arguments
297 .map(|args| serde_json::from_value(serde_json::Value::Object(args)))
298 .transpose()
299 .map_err(|e| McpError::invalid_params(format!("Invalid parameters: {}", e), None))?
300 .ok_or_else(|| McpError::invalid_params("Missing parameters", None))?;
301
302 self.upscale_image(tool_params).await
303 }
304 _ => Err(McpError::invalid_params(format!("Unknown tool: {}", params.name), None)),
305 }
306 }
307 }
308
309 fn list_resources(
310 &self,
311 _params: Option<rmcp::model::PaginatedRequestParams>,
312 _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
313 ) -> impl std::future::Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ {
314 async move {
315 debug!("Listing resources");
316
317 let models_resource = rmcp::model::Resource {
319 raw: rmcp::model::RawResource {
320 uri: "image://models".to_string(),
321 name: "Available Image Models".to_string(),
322 title: None,
323 description: Some("List of available image generation models".to_string()),
324 mime_type: Some("application/json".to_string()),
325 size: None,
326 icons: None,
327 meta: None,
328 },
329 annotations: None,
330 };
331
332 let segmentation_resource = rmcp::model::Resource {
333 raw: rmcp::model::RawResource {
334 uri: "image://segmentation_classes".to_string(),
335 name: "Segmentation Classes".to_string(),
336 title: None,
337 description: Some("List of segmentation classes for image editing (Google provider)".to_string()),
338 mime_type: Some("application/json".to_string()),
339 size: None,
340 icons: None,
341 meta: None,
342 },
343 annotations: None,
344 };
345
346 let providers_resource = rmcp::model::Resource {
347 raw: rmcp::model::RawResource {
348 uri: "image://providers".to_string(),
349 name: "Available Providers".to_string(),
350 title: None,
351 description: Some("List of available image generation providers".to_string()),
352 mime_type: Some("application/json".to_string()),
353 size: None,
354 icons: None,
355 meta: None,
356 },
357 annotations: None,
358 };
359
360 Ok(ListResourcesResult {
361 resources: vec![models_resource, segmentation_resource, providers_resource],
362 next_cursor: None,
363 meta: None,
364 })
365 }
366 }
367
368 fn read_resource(
369 &self,
370 params: rmcp::model::ReadResourceRequestParams,
371 _context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
372 ) -> impl std::future::Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ {
373 async move {
374 let uri = ¶ms.uri;
375 debug!(uri = %uri, "Reading resource");
376
377 let content = match uri.as_str() {
378 "image://models" => resources::models_resource_json(),
379 "image://segmentation_classes" => resources::segmentation_classes_resource_json(),
380 "image://providers" => resources::providers_resource_json(),
381 _ => {
382 return Err(McpError::resource_not_found(
383 format!("Unknown resource: {}", uri),
384 None,
385 ));
386 }
387 };
388
389 Ok(ReadResourceResult {
390 contents: vec![ResourceContents::text(content, uri.clone())],
391 })
392 }
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    fn test_config() -> Config {
        Config {
            project_id: "test-project".to_string(),
            location: "us-central1".to_string(),
            gcs_bucket: None,
            port: 8080,
            ..Default::default()
        }
    }

    #[test]
    fn test_server_info() {
        let server = ImageServer::new(test_config());
        let info = server.get_info();
        assert!(info.instructions.is_some());
    }

    #[test]
    fn test_tool_params_conversion() {
        let tool_params = ImageGenerateToolParams {
            prompt: "A cat".to_string(),
            negative_prompt: Some("blurry".to_string()),
            model: Some("imagen-4".to_string()),
            aspect_ratio: Some("16:9".to_string()),
            number_of_images: Some(2),
            seed: Some(42),
            output_file: None,
            output_uri: None,
        };

        let gen_params: ImageGenerateParams = tool_params.into();
        assert_eq!(gen_params.prompt, "A cat");
        assert_eq!(gen_params.negative_prompt, Some("blurry".to_string()));
        assert_eq!(gen_params.model, "imagen-4");
        assert_eq!(gen_params.aspect_ratio, "16:9");
        assert_eq!(gen_params.number_of_images, 2);
        assert_eq!(gen_params.seed, Some(42));
    }

    #[test]
    fn test_tool_params_defaults() {
        let tool_params = ImageGenerateToolParams {
            prompt: "A cat".to_string(),
            negative_prompt: None,
            model: None,
            aspect_ratio: None,
            number_of_images: None,
            seed: None,
            output_file: None,
            output_uri: None,
        };

        let gen_params: ImageGenerateParams = tool_params.into();
        assert_eq!(gen_params.model, crate::handler::DEFAULT_MODEL);
        assert_eq!(gen_params.aspect_ratio, "1:1");
        assert_eq!(gen_params.number_of_images, 1);
    }

    /// Explicit upscale arguments are forwarded unchanged.
    #[test]
    fn test_upscale_params_conversion() {
        let tool_params = ImageUpscaleToolParams {
            image: "gs://bucket/cat.png".to_string(),
            upscale_factor: Some("x4".to_string()),
            output_file: Some("out.png".to_string()),
            output_uri: None,
        };

        let upscale_params: ImageUpscaleParams = tool_params.into();
        assert_eq!(upscale_params.image, "gs://bucket/cat.png");
        assert_eq!(upscale_params.upscale_factor, "x4");
        assert_eq!(upscale_params.output_file, Some("out.png".to_string()));
        assert_eq!(upscale_params.output_uri, None);
    }

    /// An omitted upscale factor defaults to "x2".
    #[test]
    fn test_upscale_params_defaults() {
        let tool_params = ImageUpscaleToolParams {
            image: "cat.png".to_string(),
            upscale_factor: None,
            output_file: None,
            output_uri: None,
        };

        let upscale_params: ImageUpscaleParams = tool_params.into();
        assert_eq!(upscale_params.upscale_factor, "x2");
    }

    /// Only `prompt` is required when deserializing `image_generate` arguments.
    #[test]
    fn test_generate_params_deserialize_prompt_only() {
        let parsed: ImageGenerateToolParams =
            serde_json::from_value(serde_json::json!({ "prompt": "A cat" }))
                .expect("prompt-only arguments should deserialize");
        assert_eq!(parsed.prompt, "A cat");
        assert!(parsed.model.is_none());
        assert!(parsed.aspect_ratio.is_none());
        assert!(parsed.number_of_images.is_none());
    }

    /// Arguments without a prompt are rejected.
    #[test]
    fn test_generate_params_require_prompt() {
        let parsed = serde_json::from_value::<ImageGenerateToolParams>(
            serde_json::json!({ "model": "imagen-4" }),
        );
        assert!(parsed.is_err());
    }
}