1use crate::client::ComposioClient;
8use crate::error::ComposioError;
9use crate::models::MetaToolSlug;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct WorkbenchResult {
16 pub output: String,
18
19 pub successful: bool,
21
22 #[serde(skip_serializing_if = "Option::is_none")]
24 pub error: Option<String>,
25
26 pub session_id: String,
28
29 #[serde(skip_serializing_if = "Option::is_none")]
31 pub files: Option<Vec<String>>,
32}
33
/// A pandas task that `WorkbenchExecutor::generate_pandas_code` turns into Python source.
#[derive(Clone, Debug)]
pub enum PandasOperation {
    /// Download a CSV from `url` and load it into a dataframe named `df`.
    ReadCsv { url: String },

    /// Keep rows of `df` whose `column` equals `value` (compared as a string).
    FilterRows { column: String, value: String },

    /// Group `df` by `column` and print the group sizes.
    GroupBy { column: String },

    /// Call the pandas method named `operation` (e.g. `sum`) on `df[column]`.
    Aggregate { column: String, operation: String },

    /// Sort `df` by `column`, ascending or descending.
    SortBy { column: String, ascending: bool },

    /// Arbitrary Python code, passed through unchanged.
    Custom { code: String },
}
55
/// An Excel task that `WorkbenchExecutor::generate_excel_code` turns into an openpyxl script.
#[derive(Clone, Debug)]
pub enum ExcelOperation {
    /// Download the workbook at `s3_url` and print its first rows.
    Read { s3_url: String },

    /// Download the workbook, run each Python statement in `operations` against it,
    /// then save and upload it with `upload_tool` to `file_path`.
    Edit {
        s3_url: String,
        operations: Vec<String>,
        upload_tool: String,
        file_path: String,
    },

    /// Download the workbook, append `rows` to the active sheet, then save and
    /// upload it with `upload_tool` to `file_path`.
    AddRows {
        s3_url: String,
        rows: Vec<Vec<String>>,
        upload_tool: String,
        file_path: String,
    },
}
78
79pub struct WorkbenchExecutor {
81 client: Arc<ComposioClient>,
82 session_id: String,
83}
84
85impl WorkbenchExecutor {
86 pub fn new(client: Arc<ComposioClient>, session_id: impl Into<String>) -> Self {
109 Self {
110 client,
111 session_id: session_id.into(),
112 }
113 }
114
115 pub async fn execute_python(&self, code: &str) -> Result<WorkbenchResult, ComposioError> {
146 self.validate_python_syntax(code)?;
148
149 let url = format!(
151 "{}/tool_router/session/{}/execute_meta",
152 self.client.config().base_url,
153 self.session_id
154 );
155
156 let response = self
157 .client
158 .http_client()
159 .post(&url)
160 .json(&serde_json::json!({
161 "tool_slug": MetaToolSlug::ComposioRemoteWorkbench,
162 "arguments": {
163 "code": code,
164 "session_id": self.session_id,
165 }
166 }))
167 .send()
168 .await?;
169
170 if !response.status().is_success() {
171 return Err(ComposioError::from_response(response).await);
172 }
173
174 let data: serde_json::Value = response.json().await?;
175
176 let result = WorkbenchResult {
178 output: data["data"]["output"]
179 .as_str()
180 .unwrap_or("")
181 .to_string(),
182 successful: data["data"]["successful"].as_bool().unwrap_or(false),
183 error: data["data"]["error"].as_str().map(|s| s.to_string()),
184 session_id: self.session_id.clone(),
185 files: data["data"]["files"]
186 .as_array()
187 .map(|arr| {
188 arr.iter()
189 .filter_map(|v| v.as_str().map(|s| s.to_string()))
190 .collect()
191 }),
192 };
193
194 Ok(result)
195 }
196
197 pub fn generate_pandas_code(&self, operation: PandasOperation) -> String {
225 match operation {
226 PandasOperation::ReadCsv { url } => {
227 format!(
228 r#"
229import pandas as pd
230import requests
231
232# Download CSV
233response = requests.get("{}")
234df = pd.read_csv(response.content)
235print(df.head())
236print(f"\nShape: {{df.shape}}")
237print(f"Columns: {{df.columns.tolist()}}")
238"#,
239 url
240 )
241 }
242 PandasOperation::FilterRows { column, value } => {
243 format!(
244 r#"
245# Filter dataframe
246filtered = df[df['{}'] == '{}']
247print(f"Found {{len(filtered)}} rows")
248print(filtered)
249"#,
250 column, value
251 )
252 }
253 PandasOperation::GroupBy { column } => {
254 format!(
255 r#"
256# Group by column
257grouped = df.groupby('{}')
258print(grouped.size())
259"#,
260 column
261 )
262 }
263 PandasOperation::Aggregate { column, operation } => {
264 format!(
265 r#"
266# Aggregate
267result = df['{}'].{}()
268print(f"{} of {}: {{result}}")
269"#,
270 column, operation, operation, column
271 )
272 }
273 PandasOperation::SortBy { column, ascending } => {
274 format!(
275 r#"
276# Sort by column
277sorted_df = df.sort_values('{}', ascending={})
278print(sorted_df.head())
279"#,
280 column, ascending
281 )
282 }
283 PandasOperation::Custom { code } => code,
284 }
285 }
286
287 pub fn generate_excel_code(&self, operation: ExcelOperation) -> String {
315 match operation {
316 ExcelOperation::Read { s3_url } => {
317 format!(
318 r#"
319import openpyxl
320import requests
321
322# Download Excel file
323response = requests.get('{}')
324with open('temp.xlsx', 'wb') as f:
325 f.write(response.content)
326
327# Load workbook
328wb = openpyxl.load_workbook('temp.xlsx')
329ws = wb.active
330
331# Print content
332print(f"Sheet: {{ws.title}}")
333print(f"Dimensions: {{ws.dimensions}}")
334print("\nFirst 10 rows:")
335for i, row in enumerate(ws.iter_rows(values_only=True), 1):
336 if i > 10:
337 break
338 print(row)
339"#,
340 s3_url
341 )
342 }
343 ExcelOperation::Edit {
344 s3_url,
345 operations,
346 upload_tool,
347 file_path,
348 } => {
349 let ops_code = operations.join("\n");
350 format!(
351 r#"
352import openpyxl
353import requests
354
355# Download existing file
356response = requests.get('{}')
357with open('temp.xlsx', 'wb') as f:
358 f.write(response.content)
359
360# Load and edit
361wb = openpyxl.load_workbook('temp.xlsx')
362ws = wb.active
363
364# Apply operations
365{}
366
367# Save
368wb.save('temp.xlsx')
369
370# Upload back
371with open('temp.xlsx', 'rb') as f:
372 result = run_composio_tool('{}', {{
373 'path': '{}',
374 'content': f.read()
375 }})
376print(result)
377"#,
378 s3_url, ops_code, upload_tool, file_path
379 )
380 }
381 ExcelOperation::AddRows {
382 s3_url,
383 rows,
384 upload_tool,
385 file_path,
386 } => {
387 let rows_code = rows
388 .iter()
389 .map(|row| format!("ws.append({:?})", row))
390 .collect::<Vec<_>>()
391 .join("\n");
392
393 format!(
394 r#"
395import openpyxl
396import requests
397
398# Download existing file
399response = requests.get('{}')
400with open('temp.xlsx', 'wb') as f:
401 f.write(response.content)
402
403# Load workbook
404wb = openpyxl.load_workbook('temp.xlsx')
405ws = wb.active
406
407# Add new rows
408{}
409
410# Save
411wb.save('temp.xlsx')
412
413# Upload back
414with open('temp.xlsx', 'rb') as f:
415 result = run_composio_tool('{}', {{
416 'path': '{}',
417 'content': f.read()
418 }})
419print(result)
420"#,
421 s3_url, rows_code, upload_tool, file_path
422 )
423 }
424 }
425 }
426
427 fn validate_python_syntax(&self, code: &str) -> Result<(), ComposioError> {
429 if code.trim().is_empty() {
431 return Err(ComposioError::ValidationError(
432 "Python code cannot be empty".to_string(),
433 ));
434 }
435
436 let mut paren_count = 0;
438 let mut bracket_count = 0;
439 let mut brace_count = 0;
440
441 for ch in code.chars() {
442 match ch {
443 '(' => paren_count += 1,
444 ')' => paren_count -= 1,
445 '[' => bracket_count += 1,
446 ']' => bracket_count -= 1,
447 '{' => brace_count += 1,
448 '}' => brace_count -= 1,
449 _ => {}
450 }
451 }
452
453 if paren_count != 0 {
454 return Err(ComposioError::ValidationError(
455 "Unbalanced parentheses in Python code".to_string(),
456 ));
457 }
458
459 if bracket_count != 0 {
460 return Err(ComposioError::ValidationError(
461 "Unbalanced brackets in Python code".to_string(),
462 ));
463 }
464
465 if brace_count != 0 {
466 return Err(ComposioError::ValidationError(
467 "Unbalanced braces in Python code".to_string(),
468 ));
469 }
470
471 Ok(())
472 }
473
474 pub fn session_id(&self) -> &str {
476 &self.session_id
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an executor for session `session_123`, backed by a client configured
    /// with a dummy API key.
    fn executor() -> WorkbenchExecutor {
        let client = ComposioClient::builder().api_key("test").build().unwrap();
        WorkbenchExecutor::new(Arc::new(client), "session_123")
    }

    #[test]
    fn test_pandas_read_csv_code_generation() {
        let script = executor().generate_pandas_code(PandasOperation::ReadCsv {
            url: "https://example.com/data.csv".to_string(),
        });

        for needle in [
            "import pandas as pd",
            "requests.get",
            "https://example.com/data.csv",
            "pd.read_csv",
        ] {
            assert!(script.contains(needle));
        }
    }

    #[test]
    fn test_pandas_filter_code_generation() {
        let script = executor().generate_pandas_code(PandasOperation::FilterRows {
            column: "age".to_string(),
            value: "25".to_string(),
        });

        for needle in ["df['age']", "== '25'"] {
            assert!(script.contains(needle));
        }
    }

    #[test]
    fn test_excel_read_code_generation() {
        let script = executor().generate_excel_code(ExcelOperation::Read {
            s3_url: "https://s3.amazonaws.com/bucket/file.xlsx".to_string(),
        });

        for needle in ["import openpyxl", "requests.get", "load_workbook"] {
            assert!(script.contains(needle));
        }
    }

    #[test]
    fn test_python_syntax_validation_empty() {
        assert!(executor().validate_python_syntax("").is_err());
    }

    #[test]
    fn test_python_syntax_validation_unbalanced_parens() {
        let outcome = executor().validate_python_syntax("print('hello'");
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("parentheses"));
    }

    #[test]
    fn test_python_syntax_validation_valid() {
        assert!(executor().validate_python_syntax("print('hello')").is_ok());
    }
}