Skip to main content

avalonia_mcp_tools/
data_access_pattern_tool.rs

1//! Data Access Pattern tool - Repository and EF Core patterns
2use avalonia_mcp_core::error::AvaloniaMcpError;
3use avalonia_mcp_core::markdown::MarkdownOutputBuilder;
4use rmcp::model::{CallToolResult, Content};
5use rmcp::tool;
6use serde::{Deserialize, Serialize};
7use schemars::JsonSchema;
8
9#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
10pub struct DataAccessPatternParams {
11    pub pattern: Option<String>,
12    pub include_examples: Option<bool>,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
16pub struct AsyncDataAccessParams {
17    pub service_name: String,
18    pub include_caching: Option<bool>,
19    pub include_retry: Option<bool>,
20    pub caching_provider: Option<String>,
21}
22
/// Stateless tool that renders Markdown guides for data access patterns.
///
/// The type has no fields, so all instances are interchangeable.
#[derive(Default, Clone, Debug)]
pub struct DataAccessPatternTool;
25
26impl DataAccessPatternTool {
27    pub fn new() -> Self { Self }
28
29    #[tool(description = "Generate data access patterns for AvaloniaUI applications. Covers EF Core, Dapper, repository pattern, and database best practices.")]
30    pub async fn generate_data_access_pattern(
31        &self,
32        params: DataAccessPatternParams,
33    ) -> Result<CallToolResult, AvaloniaMcpError> {
34        let include_examples = params.include_examples.unwrap_or(true);
35        let pattern = params.pattern.as_deref().unwrap_or("efcore");
36
37        let output = match pattern {
38            "efcore" => self.generate_efcore(include_examples),
39            "dapper" => self.generate_dapper(include_examples),
40            "repository" => self.generate_repository(include_examples),
41            _ => self.generate_efcore(include_examples),
42        };
43
44        Ok(CallToolResult::success(vec![Content::text(output)]))
45    }
46
47    #[tool(description = "Creates async data access patterns with caching and error handling for AvaloniaUI applications")]
48    pub async fn generate_async_data_access(
49        &self,
50        params: AsyncDataAccessParams,
51    ) -> Result<CallToolResult, AvaloniaMcpError> {
52        if params.service_name.is_empty() {
53            return Err(AvaloniaMcpError::validation("Service name cannot be empty"));
54        }
55
56        let include_caching = params.include_caching.unwrap_or(true);
57        let include_retry = params.include_retry.unwrap_or(true);
58        let cache_provider = params.caching_provider.as_deref().unwrap_or("memory");
59
60        let table_name = params.service_name.replace("Service", "").to_lowercase();
61
62        let mut builder = MarkdownOutputBuilder::new()
63            .heading(1, &format!("Async Data Access Service: {}", params.service_name))
64            .heading(2, "Configuration")
65            .task_list(vec![
66                (true, format!("Service: {}", params.service_name)),
67                (true, format!("Caching: {}", include_caching)),
68                (true, format!("Retry: {}", include_retry)),
69            ])
70            .heading(2, "Service Interface")
71            .code_block(
72                "csharp",
73                &format!(
74                    "public interface I{service}\n{{\n    Task<T?> GetDataAsync<T>(string key);\n    Task SaveDataAsync<T>(string key, T data);\n}}",
75                    service = params.service_name
76                ),
77            )
78            .heading(2, "Async Implementation")
79            .code_block(
80                "csharp",
81                &format!(
82                    r#"public class {service} : I{service}
83{{
84    private readonly IDbConnection _connection;
85
86    public async Task<T?> GetDataAsync<T>(string key)
87    {{
88        return await _connection.QueryFirstOrDefaultAsync<T>(
89            "SELECT * FROM {table} WHERE Key = @Key", new {{ Key = key }});
90    }}
91
92    public async Task SaveDataAsync<T>(string key, T data)
93    {{
94        await _connection.ExecuteAsync(
95            "INSERT OR REPLACE INTO {table} (Key, Data) VALUES (@Key, @Data)",
96            new {{ Key = key, Data = JsonSerializer.Serialize(data) }});
97    }}
98}}"#,
99                    service = params.service_name,
100                    table = table_name
101                ),
102            );
103
104        if include_caching {
105            builder = builder
106                .heading(2, "Caching Implementation")
107                .code_block(
108                    "csharp",
109                    &format!(
110                        r#"// {provider} caching implementation
111private readonly IMemoryCache _cache;
112private readonly TimeSpan _cacheDuration = TimeSpan.FromMinutes(5);
113
114public async Task<T?> GetCachedAsync<T>(string key, Func<Task<T>> factory)
115{{
116    return await _cache.GetOrCreateAsync(key, async entry =>
117    {{
118        entry.AbsoluteExpirationRelativeToNow = _cacheDuration;
119        return await factory();
120    }});
121}}"#,
122                        provider = cache_provider
123                    ),
124                );
125        }
126
127        if include_retry {
128            builder = builder
129                .heading(2, "Retry Policy")
130                .code_block(
131                    "csharp",
132                    r#"// Retry policy with exponential backoff
133public async Task<T> ExecuteWithRetryAsync<T>(Func<Task<T>> operation, int maxRetries = 3)
134{
135    for (int attempt = 1; attempt <= maxRetries; attempt++)
136    {
137        try { return await operation(); }
138        catch (Exception ex) when (attempt < maxRetries)
139        {
140            await Task.Delay(TimeSpan.FromMilliseconds(100 * Math.Pow(2, attempt)));
141        }
142    }
143    throw new Exception("Max retries exceeded");
144}"#,
145                );
146        }
147
148        builder = builder
149            .heading(2, "Performance Considerations")
150            .list(vec![
151                "All DB operations are async",
152                "Use cancellation tokens",
153                "Connection pooling enabled",
154                "Query optimization recommended",
155            ]);
156
157        Ok(CallToolResult::success(vec![Content::text(builder.build())]))
158    }
159
160    fn generate_efcore(&self, include_examples: bool) -> String {
161        let mut builder = MarkdownOutputBuilder::new()
162            .heading(1, "Entity Framework Core Pattern")
163            .paragraph("EF Core ORM for database access in AvaloniaUI applications.")
164            .heading(2, "DbContext Setup")
165            .code_block("csharp", r#"public class AppDbContext : DbContext
166{
167    public AppDbContext(DbContextOptions<AppDbContext> options)
168        : base(options) { }
169    
170    public DbSet<User> Users => Set<User>();
171    public DbSet<Order> Orders => Set<Order>();
172    
173    protected override void OnModelCreating(ModelBuilder modelBuilder)
174    {
175        base.OnModelCreating(modelBuilder);
176        
177        // Configure entities
178        modelBuilder.Entity<User>(entity =>
179        {
180            entity.HasKey(e => e.Id);
181            entity.Property(e => e.Email).IsRequired().HasMaxLength(256);
182            entity.HasIndex(e => e.Email).IsUnique();
183        });
184    }
185}
186
187// Entity
188public class User
189{
190    public int Id { get; set; }
191    public string Name { get; set; } = "";
192    public string Email { get; set; } = "";
193    public DateTime CreatedAt { get; set; } = DateTime.UtcNow;
194}"#);
195
196        if include_examples {
197            builder = builder
198                .heading(2, "DI Registration")
199                .code_block("csharp", r#"// Register DbContext
200services.AddDbContext<AppDbContext>(options =>
201    options.UseSqlite("Data Source=app.db"));
202
203// Register for ViewModel injection
204services.AddScoped<IUserService, UserService>();"#)
205                .heading(2, "Async Queries")
206                .code_block("csharp", r#"// In your service
207public class UserService
208{
209    private readonly AppDbContext _context;
210    
211    public UserService(AppDbContext context) => _context = context;
212    
213    public async Task<User?> GetByIdAsync(int id) =>
214        await _context.Users.FindAsync(id);
215    
216    public async Task<List<User>> GetAllAsync() =>
217        await _context.Users.ToListAsync();
218    
219    public async Task<User> CreateAsync(User user)
220    {
221        _context.Users.Add(user);
222        await _context.SaveChangesAsync();
223        return user;
224    }
225    
226    // Efficient loading
227    public async Task<User?> WithOrdersAsync(int id) =>
228        await _context.Users
229            .Include(u => u.Orders)
230            .FirstOrDefaultAsync(u => u.Id == id);
231}"#);
232        }
233
234        builder.heading(2, "Best Practices")
235            .task_list(vec![(true, "Use async methods"), (true, "Enable sensitive data logging in dev"), (true, "Use migrations for schema changes"), (true, "Implement soft delete"), (false, "Add retry policies for transient errors")])
236            .build()
237    }
238
239    fn generate_dapper(&self, _include_examples: bool) -> String {
240        MarkdownOutputBuilder::new()
241            .heading(1, "Dapper Micro-ORM")
242            .paragraph("Lightweight database access with Dapper.")
243            .heading(2, "Setup")
244            .code_block("csharp", r#"// Install: dotnet add package Dapper
245
246public class UserRepository
247{
248    private readonly IDbConnection _db;
249    
250    public UserRepository(IDbConnection db) => _db = db;
251    
252    public async Task<User?> GetByIdAsync(int id) =>
253        await _db.QueryFirstOrDefaultAsync<User>(
254            "SELECT * FROM Users WHERE Id = @Id", 
255            new { Id = id });
256    
257    public async Task<IEnumerable<User>> GetAllAsync() =>
258        await _db.QueryAsync<User>("SELECT * FROM Users");
259    
260    public async Task<int> CreateAsync(User user) =>
261        await _db.ExecuteAsync(
262            "INSERT INTO Users (Name, Email) VALUES (@Name, @Email)",
263            user);
264}"#)
265            .build()
266    }
267
268    fn generate_repository(&self, _include_examples: bool) -> String {
269        MarkdownOutputBuilder::new()
270            .heading(1, "Repository Pattern")
271            .paragraph("Abstract data access behind repository interface.")
272            .heading(2, "Generic Repository")
273            .code_block("csharp", r#"public interface IRepository<T> where T : class
274{
275    Task<IEnumerable<T>> GetAllAsync();
276    Task<T?> GetByIdAsync(int id);
277    Task<T> AddAsync(T entity);
278    Task DeleteAsync(int id);
279}
280
281public class EfRepository<T> : IRepository<T> where T : class
282{
283    protected readonly DbContext Context;
284    protected readonly DbSet<T> DbSet;
285    
286    public EfRepository(DbContext context)
287    {
288        Context = context;
289        DbSet = context.Set<T>();
290    }
291    
292    public virtual async Task<IEnumerable<T>> GetAllAsync() =>
293        await DbSet.ToListAsync();
294    
295    public virtual async Task<T?> GetByIdAsync(int id) =>
296        await DbSet.FindAsync(id);
297    
298    public virtual async Task<T> AddAsync(T entity)
299    {
300        await DbSet.AddAsync(entity);
301        await Context.SaveChangesAsync();
302        return entity;
303    }
304    
305    public virtual async Task DeleteAsync(int id)
306    {
307        var entity = await DbSet.FindAsync(id);
308        if (entity != null)
309        {
310            DbSet.Remove(entity);
311            await Context.SaveChangesAsync();
312        }
313    }
314}"#)
315            .build()
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds parameters for the async data access tool.
    fn async_params(
        service_name: &str,
        include_caching: Option<bool>,
        include_retry: Option<bool>,
        caching_provider: Option<&str>,
    ) -> AsyncDataAccessParams {
        AsyncDataAccessParams {
            service_name: service_name.to_string(),
            include_caching,
            include_retry,
            caching_provider: caching_provider.map(str::to_string),
        }
    }

    /// Asserts that the tool result is not flagged as an error.
    fn assert_not_error(result: &CallToolResult) {
        assert!(!matches!(result.is_error, Some(true)));
    }

    #[tokio::test]
    async fn test_generate_data_access() {
        let params = DataAccessPatternParams {
            pattern: None,
            include_examples: Some(true),
        };
        let outcome = DataAccessPatternTool::new()
            .generate_data_access_pattern(params)
            .await
            .unwrap();
        assert_not_error(&outcome);
    }

    #[tokio::test]
    async fn test_generate_async_data_access_success() {
        let params = async_params("UserService", Some(true), Some(true), Some("redis"));
        let outcome = DataAccessPatternTool::new()
            .generate_async_data_access(params)
            .await
            .unwrap();
        assert_not_error(&outcome);
        assert!(!outcome.content.is_empty());
    }

    #[tokio::test]
    async fn test_generate_async_data_access_no_caching() {
        let params = async_params("OrderService", Some(false), Some(false), None);
        let outcome = DataAccessPatternTool::new()
            .generate_async_data_access(params)
            .await
            .unwrap();
        assert_not_error(&outcome);
    }

    #[tokio::test]
    async fn test_generate_async_data_access_empty_name() {
        let params = async_params("", None, None, None);
        let outcome = DataAccessPatternTool::new()
            .generate_async_data_access(params)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_generate_async_data_access_defaults() {
        let params = async_params("DataAccess", None, None, None);
        let outcome = DataAccessPatternTool::new()
            .generate_async_data_access(params)
            .await
            .unwrap();
        assert_not_error(&outcome);
    }
}
381}