changepacks_csharp/
finder.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use changepacks_core::{Project, ProjectFinder};
4use quick_xml::Reader;
5use quick_xml::events::Event;
6use std::{
7    collections::HashMap,
8    path::{Path, PathBuf},
9};
10use tokio::fs::read_to_string;
11
12use crate::{package::CSharpPackage, workspace::CSharpWorkspace};
13
/// Discovers C# projects by visiting `.csproj` files.
///
/// Each discovered project is stored keyed by the path it was visited at, so
/// visiting the same path twice registers it only once.
#[derive(Debug)]
pub struct CSharpProjectFinder {
    /// Discovered projects, keyed by the `.csproj` path passed to `visit`.
    projects: HashMap<PathBuf, Project>,
    /// Values returned from `project_files()` (currently just `".csproj"`).
    /// NOTE(review): how the caller matches these against file names is
    /// defined outside this file.
    project_files: Vec<&'static str>,
}
19
20impl Default for CSharpProjectFinder {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl CSharpProjectFinder {
27    pub fn new() -> Self {
28        Self {
29            projects: HashMap::new(),
30            project_files: vec![".csproj"],
31        }
32    }
33
34    /// Extract the project name from the .csproj file path (filename without extension)
35    fn extract_name_from_path(path: &Path) -> Option<String> {
36        path.file_stem()
37            .and_then(|s| s.to_str())
38            .map(|s| s.to_string())
39    }
40}
41
/// Extract a project name from an MSBuild project path, handling both Windows
/// (`\`) and Unix (`/`) separators.
///
/// Input: `"..\CoreLib\CoreLib.csproj"` or `"../CoreLib/CoreLib.csproj"`
/// Output: `Some("CoreLib")`
///
/// Leading/trailing whitespace is trimmed first, so padded `Include` values
/// still resolve. Returns `None` when the last path segment does not end in
/// `.csproj`, or when nothing is left once the extension is removed
/// (e.g. `"../Dir/.csproj"`), so callers never receive an empty name.
fn extract_project_name_from_path(path_str: &str) -> Option<String> {
    // `rsplit` always yields at least one item: the last segment, or the whole
    // string when it contains no separator.
    let filename = path_str.trim().rsplit(['\\', '/']).next()?;

    filename
        .strip_suffix(".csproj")
        .filter(|stem| !stem.is_empty())
        .map(String::from)
}
52
53impl CSharpProjectFinder {
54    /// Extract version from .csproj XML content using quick-xml
55    fn extract_version(content: &str) -> Option<String> {
56        let mut reader = Reader::from_str(content);
57        let mut buf = Vec::new();
58        let mut in_property_group = false;
59        let mut in_version = false;
60
61        loop {
62            match reader.read_event_into(&mut buf) {
63                Ok(Event::Start(e)) => {
64                    let name = e.local_name();
65                    if name.as_ref() == b"PropertyGroup" {
66                        in_property_group = true;
67                    } else if in_property_group && name.as_ref() == b"Version" {
68                        in_version = true;
69                    }
70                }
71                Ok(Event::End(e)) => {
72                    let name = e.local_name();
73                    if name.as_ref() == b"PropertyGroup" {
74                        in_property_group = false;
75                    } else if name.as_ref() == b"Version" {
76                        in_version = false;
77                    }
78                }
79                Ok(Event::Text(e)) => {
80                    if in_version && let Ok(text) = e.decode() {
81                        let version = text.trim().to_string();
82                        if !version.is_empty() {
83                            return Some(version);
84                        }
85                    }
86                }
87                Ok(Event::Eof) => break,
88                Err(_) => break,
89                _ => {}
90            }
91            buf.clear();
92        }
93        None
94    }
95
96    /// Extract PackageReference dependencies from .csproj XML content using quick-xml
97    #[allow(dead_code)]
98    fn extract_package_references(content: &str) -> Vec<String> {
99        let mut reader = Reader::from_str(content);
100        let mut buf = Vec::new();
101        let mut packages = Vec::new();
102
103        loop {
104            match reader.read_event_into(&mut buf) {
105                Ok(Event::Empty(e)) | Ok(Event::Start(e)) => {
106                    if e.local_name().as_ref() == b"PackageReference" {
107                        for attr in e.attributes().flatten() {
108                            if attr.key.as_ref() == b"Include"
109                                && let Ok(value) = attr.unescape_value()
110                            {
111                                packages.push(value.to_string());
112                            }
113                        }
114                    }
115                }
116                Ok(Event::Eof) => break,
117                Err(_) => break,
118                _ => {}
119            }
120            buf.clear();
121        }
122        packages
123    }
124
125    /// Extract ProjectReference dependencies from .csproj XML content using quick-xml
126    /// Returns the project names (extracted from paths)
127    fn extract_project_references(content: &str) -> Vec<String> {
128        let mut reader = Reader::from_str(content);
129        let mut buf = Vec::new();
130        let mut projects = Vec::new();
131
132        loop {
133            match reader.read_event_into(&mut buf) {
134                Ok(Event::Empty(e)) | Ok(Event::Start(e)) => {
135                    if e.local_name().as_ref() == b"ProjectReference" {
136                        for attr in e.attributes().flatten() {
137                            if attr.key.as_ref() == b"Include"
138                                && let Ok(value) = attr.unescape_value()
139                            {
140                                // Extract project name from path like "..\CoreLib\CoreLib.csproj"
141                                // Handle both Windows (\) and Unix (/) path separators
142                                if let Some(name) = extract_project_name_from_path(&value) {
143                                    projects.push(name);
144                                }
145                            }
146                        }
147                    }
148                }
149                Ok(Event::Eof) => break,
150                Err(_) => break,
151                _ => {}
152            }
153            buf.clear();
154        }
155        projects
156    }
157
158    /// Check if this project is part of a solution (workspace)
159    /// A project is considered a workspace if there's a .sln file in the same directory
160    fn is_workspace(path: &Path) -> bool {
161        if let Some(parent) = path.parent() {
162            // Check if there's a .sln file in the parent directory
163            if let Ok(entries) = std::fs::read_dir(parent) {
164                for entry in entries.flatten() {
165                    if let Some(ext) = entry.path().extension()
166                        && ext == "sln"
167                    {
168                        return true;
169                    }
170                }
171            }
172        }
173        false
174    }
175}
176
177#[async_trait]
178impl ProjectFinder for CSharpProjectFinder {
179    fn projects(&self) -> Vec<&Project> {
180        self.projects.values().collect::<Vec<_>>()
181    }
182
183    fn projects_mut(&mut self) -> Vec<&mut Project> {
184        self.projects.values_mut().collect::<Vec<_>>()
185    }
186
187    fn project_files(&self) -> &[&str] {
188        &self.project_files
189    }
190
191    async fn visit(&mut self, path: &Path, relative_path: &Path) -> Result<()> {
192        // Check if this is a .csproj file
193        if path.is_file() {
194            let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
195
196            if extension != "csproj" {
197                return Ok(());
198            }
199
200            if self.projects.contains_key(path) {
201                return Ok(());
202            }
203
204            // Read .csproj content
205            let csproj_content = read_to_string(path).await?;
206
207            let name = Self::extract_name_from_path(path);
208            let version = Self::extract_version(&csproj_content);
209            let is_workspace = Self::is_workspace(path);
210
211            let (path_key, mut project) = if is_workspace {
212                (
213                    path.to_path_buf(),
214                    Project::Workspace(Box::new(CSharpWorkspace::new(
215                        name,
216                        version,
217                        path.to_path_buf(),
218                        relative_path.to_path_buf(),
219                    ))),
220                )
221            } else {
222                (
223                    path.to_path_buf(),
224                    Project::Package(Box::new(CSharpPackage::new(
225                        name,
226                        version,
227                        path.to_path_buf(),
228                        relative_path.to_path_buf(),
229                    ))),
230                )
231            };
232
233            // Add ProjectReference dependencies (local project references)
234            for dep in Self::extract_project_references(&csproj_content) {
235                project.add_dependency(&dep);
236            }
237
238            self.projects.insert(path_key, project);
239        }
240        Ok(())
241    }
242}
243
// Unit tests for `CSharpProjectFinder`: construction, `visit` behaviour on
// various inputs, and the XML extraction helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// A freshly constructed finder reports `.csproj` and holds no projects.
    #[tokio::test]
    async fn test_new() {
        let finder = CSharpProjectFinder::new();
        assert_eq!(finder.project_files(), &[".csproj"]);
        assert_eq!(finder.projects().len(), 0);
    }

    /// `Default` produces the same state as `new`.
    #[tokio::test]
    async fn test_default() {
        let finder = CSharpProjectFinder::default();
        assert_eq!(finder.project_files(), &[".csproj"]);
        assert_eq!(finder.projects().len(), 0);
    }

    /// A lone `.csproj` (no `.sln` next to it) is registered as a package with
    /// its file-stem name and `<Version>`.
    #[tokio::test]
    async fn test_visit_package() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = temp_dir.path().join("TestProject.csproj");
        fs::write(
            &csproj_path,
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>1.0.0</Version>
  </PropertyGroup>
</Project>
"#,
        )
        .unwrap();

        let mut finder = CSharpProjectFinder::new();
        finder
            .visit(&csproj_path, &PathBuf::from("TestProject.csproj"))
            .await
            .unwrap();

        assert_eq!(finder.projects().len(), 1);
        match finder.projects()[0] {
            Project::Package(pkg) => {
                assert_eq!(pkg.name(), Some("TestProject"));
                assert_eq!(pkg.version(), Some("1.0.0"));
            }
            _ => panic!("Expected Package"),
        }

        temp_dir.close().unwrap();
    }

    /// A `.sln` file in the same directory turns the project into a workspace.
    #[tokio::test]
    async fn test_visit_workspace_with_sln() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = temp_dir.path().join("TestProject.csproj");
        let sln_path = temp_dir.path().join("TestSolution.sln");

        fs::write(
            &csproj_path,
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>1.0.0</Version>
  </PropertyGroup>
</Project>
"#,
        )
        .unwrap();

        // Only the `.sln` extension matters; the file content is not parsed.
        fs::write(&sln_path, "Microsoft Visual Studio Solution File").unwrap();

        let mut finder = CSharpProjectFinder::new();
        finder
            .visit(&csproj_path, &PathBuf::from("TestProject.csproj"))
            .await
            .unwrap();

        assert_eq!(finder.projects().len(), 1);
        match finder.projects()[0] {
            Project::Workspace(ws) => {
                assert_eq!(ws.name(), Some("TestProject"));
                assert_eq!(ws.version(), Some("1.0.0"));
            }
            _ => panic!("Expected Workspace"),
        }

        temp_dir.close().unwrap();
    }

    /// A project without `<Version>` is still registered, with `version() == None`.
    #[tokio::test]
    async fn test_visit_package_without_version() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = temp_dir.path().join("TestProject.csproj");
        fs::write(
            &csproj_path,
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <OutputType>Exe</OutputType>
  </PropertyGroup>
</Project>
"#,
        )
        .unwrap();

        let mut finder = CSharpProjectFinder::new();
        finder
            .visit(&csproj_path, &PathBuf::from("TestProject.csproj"))
            .await
            .unwrap();

        assert_eq!(finder.projects().len(), 1);
        match finder.projects()[0] {
            Project::Package(pkg) => {
                assert_eq!(pkg.name(), Some("TestProject"));
                assert_eq!(pkg.version(), None);
            }
            _ => panic!("Expected Package"),
        }

        temp_dir.close().unwrap();
    }

    /// Files with another extension are ignored even if they contain XML.
    #[tokio::test]
    async fn test_visit_non_csproj_file() {
        let temp_dir = TempDir::new().unwrap();
        let other_file = temp_dir.path().join("other.xml");
        fs::write(&other_file, r#"<root>content</root>"#).unwrap();

        let mut finder = CSharpProjectFinder::new();
        finder
            .visit(&other_file, &PathBuf::from("other.xml"))
            .await
            .unwrap();

        assert_eq!(finder.projects().len(), 0);

        temp_dir.close().unwrap();
    }

    /// Directories are ignored.
    #[tokio::test]
    async fn test_visit_directory() {
        let temp_dir = TempDir::new().unwrap();
        let dir_path = temp_dir.path().join("some_dir");
        fs::create_dir_all(&dir_path).unwrap();

        let mut finder = CSharpProjectFinder::new();
        finder
            .visit(&dir_path, &PathBuf::from("some_dir"))
            .await
            .unwrap();

        assert_eq!(finder.projects().len(), 0);

        temp_dir.close().unwrap();
    }

    /// Visiting the same path twice registers the project only once.
    #[tokio::test]
    async fn test_visit_duplicate() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = temp_dir.path().join("TestProject.csproj");
        fs::write(
            &csproj_path,
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>1.0.0</Version>
  </PropertyGroup>
</Project>
"#,
        )
        .unwrap();

        let mut finder = CSharpProjectFinder::new();
        finder
            .visit(&csproj_path, &PathBuf::from("TestProject.csproj"))
            .await
            .unwrap();
        finder
            .visit(&csproj_path, &PathBuf::from("TestProject.csproj"))
            .await
            .unwrap();

        assert_eq!(finder.projects().len(), 1);

        temp_dir.close().unwrap();
    }

    /// Projects in different directories are tracked independently.
    #[tokio::test]
    async fn test_visit_multiple_packages() {
        let temp_dir = TempDir::new().unwrap();
        let csproj1 = temp_dir.path().join("Project1").join("Project1.csproj");
        let csproj2 = temp_dir.path().join("Project2").join("Project2.csproj");
        fs::create_dir_all(csproj1.parent().unwrap()).unwrap();
        fs::create_dir_all(csproj2.parent().unwrap()).unwrap();
        fs::write(
            &csproj1,
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>1.0.0</Version>
  </PropertyGroup>
</Project>
"#,
        )
        .unwrap();
        fs::write(
            &csproj2,
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>2.0.0</Version>
  </PropertyGroup>
</Project>
"#,
        )
        .unwrap();

        let mut finder = CSharpProjectFinder::new();
        finder
            .visit(&csproj1, &PathBuf::from("Project1/Project1.csproj"))
            .await
            .unwrap();
        finder
            .visit(&csproj2, &PathBuf::from("Project2/Project2.csproj"))
            .await
            .unwrap();

        assert_eq!(finder.projects().len(), 2);

        temp_dir.close().unwrap();
    }

    /// `projects_mut` hands out mutable access that persists changes.
    #[tokio::test]
    async fn test_projects_mut() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = temp_dir.path().join("TestProject.csproj");
        fs::write(
            &csproj_path,
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>1.0.0</Version>
  </PropertyGroup>
</Project>
"#,
        )
        .unwrap();

        let mut finder = CSharpProjectFinder::new();
        finder
            .visit(&csproj_path, &PathBuf::from("TestProject.csproj"))
            .await
            .unwrap();

        let mut projects = finder.projects_mut();
        assert_eq!(projects.len(), 1);
        match &mut projects[0] {
            Project::Package(pkg) => {
                assert!(!pkg.is_changed());
                pkg.set_changed(true);
                assert!(pkg.is_changed());
            }
            _ => panic!("Expected Package"),
        }

        temp_dir.close().unwrap();
    }

    /// ProjectReferences become dependencies (by project name); PackageReferences do not.
    #[tokio::test]
    async fn test_visit_package_with_project_references() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = temp_dir.path().join("TestProject.csproj");
        fs::write(
            &csproj_path,
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>1.0.0</Version>
  </PropertyGroup>
  <ItemGroup>
    <PackageReference Include="Newtonsoft.Json" Version="13.0.1" />
  </ItemGroup>
  <ItemGroup>
    <ProjectReference Include="..\CoreLib\CoreLib.csproj" />
    <ProjectReference Include="..\Utils\Utils.csproj" />
  </ItemGroup>
</Project>
"#,
        )
        .unwrap();

        let mut finder = CSharpProjectFinder::new();
        finder
            .visit(&csproj_path, &PathBuf::from("TestProject.csproj"))
            .await
            .unwrap();

        let projects = finder.projects();
        assert_eq!(projects.len(), 1);
        match projects[0] {
            Project::Package(pkg) => {
                assert_eq!(pkg.name(), Some("TestProject"));
                let deps = pkg.dependencies();
                // Only ProjectReferences are tracked (not PackageReferences)
                assert_eq!(deps.len(), 2);
                assert!(deps.contains("CoreLib"));
                assert!(deps.contains("Utils"));
            }
            _ => panic!("Expected Package"),
        }

        temp_dir.close().unwrap();
    }

    /// `extract_version` reads `<Version>` from a PropertyGroup, or returns `None`.
    #[test]
    fn test_extract_version() {
        let content = r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>1.2.3</Version>
  </PropertyGroup>
</Project>"#;
        assert_eq!(
            CSharpProjectFinder::extract_version(content),
            Some("1.2.3".to_string())
        );

        let no_version = r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <OutputType>Exe</OutputType>
  </PropertyGroup>
</Project>"#;
        assert_eq!(CSharpProjectFinder::extract_version(no_version), None);
    }

    /// `extract_package_references` returns each PackageReference `Include` id.
    #[test]
    fn test_extract_package_references() {
        let content = r#"<Project Sdk="Microsoft.NET.Sdk">
  <ItemGroup>
    <PackageReference Include="Newtonsoft.Json" Version="13.0.1" />
    <PackageReference Include="System.CommandLine" Version="2.0.0-beta4.22272.1" />
  </ItemGroup>
</Project>"#;
        let refs = CSharpProjectFinder::extract_package_references(content);
        assert_eq!(refs.len(), 2);
        assert!(refs.contains(&"Newtonsoft.Json".to_string()));
        assert!(refs.contains(&"System.CommandLine".to_string()));
    }

    /// `extract_project_references` maps ProjectReference paths to project names.
    #[test]
    fn test_extract_project_references() {
        let content = r#"<Project Sdk="Microsoft.NET.Sdk">
  <ItemGroup>
    <ProjectReference Include="..\CoreLib\CoreLib.csproj" />
    <ProjectReference Include="..\Utils\Utils.csproj" />
  </ItemGroup>
</Project>"#;
        let refs = CSharpProjectFinder::extract_project_references(content);
        assert_eq!(refs.len(), 2);
        assert!(refs.contains(&"CoreLib".to_string()));
        assert!(refs.contains(&"Utils".to_string()));
    }

    /// Path-to-name extraction across separator styles, plus a non-.csproj input.
    #[test]
    fn test_extract_project_name_from_path() {
        // Windows-style paths
        assert_eq!(
            super::extract_project_name_from_path(r"..\CoreLib\CoreLib.csproj"),
            Some("CoreLib".to_string())
        );
        assert_eq!(
            super::extract_project_name_from_path(r"..\..\Utils\Utils.csproj"),
            Some("Utils".to_string())
        );
        // Unix-style paths
        assert_eq!(
            super::extract_project_name_from_path("../CoreLib/CoreLib.csproj"),
            Some("CoreLib".to_string())
        );
        // Just filename
        assert_eq!(
            super::extract_project_name_from_path("MyProject.csproj"),
            Some("MyProject".to_string())
        );
        // Invalid - no .csproj extension
        assert_eq!(super::extract_project_name_from_path("MyProject.txt"), None);
    }
}