Skip to main content

changepacks_csharp/
finder.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use changepacks_core::{Project, ProjectFinder};
4use quick_xml::Reader;
5use quick_xml::events::Event;
6use std::{
7    collections::HashMap,
8    path::{Path, PathBuf},
9};
10use tokio::fs::read_to_string;
11
12use crate::{package::CSharpPackage, workspace::CSharpWorkspace};
13
14#[derive(Debug)]
15pub struct CSharpProjectFinder {
16    projects: HashMap<PathBuf, Project>,
17    project_files: Vec<&'static str>,
18}
19
20impl Default for CSharpProjectFinder {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl CSharpProjectFinder {
27    #[must_use]
28    pub fn new() -> Self {
29        Self {
30            projects: HashMap::new(),
31            project_files: vec![".csproj"],
32        }
33    }
34
35    /// Extract the project name from the .csproj file path (filename without extension)
36    fn extract_name_from_path(path: &Path) -> Option<String> {
37        path.file_stem()
38            .and_then(|s| s.to_str())
39            .map(std::string::ToString::to_string)
40    }
41}
42
/// Extract project name from a path string, handling both Windows and Unix separators
/// Input: `"..\CoreLib\CoreLib.csproj"` or `"../CoreLib/CoreLib.csproj"`
/// Output: `"CoreLib"`
///
/// Surrounding whitespace is ignored and the `.csproj` extension is matched
/// ASCII case-insensitively (Windows file names are case-insensitive).
///
/// Returns `None` when the last path segment does not end in `.csproj` or when
/// nothing is left once the extension is removed (e.g. `".csproj"`), so callers
/// never receive an empty project name.
fn extract_project_name_from_path(path_str: &str) -> Option<String> {
    const EXTENSION: &str = ".csproj";

    // Split by both Windows (\) and Unix (/) separators and keep the last segment.
    let filename = path_str.trim().rsplit(['\\', '/']).next()?;

    // A file name shorter than the extension cannot match.
    let stem_len = filename.len().checked_sub(EXTENSION.len())?;

    // `str::get` yields `None` instead of panicking when `stem_len` falls inside a
    // multi-byte character, which also means the suffix cannot be `.csproj`.
    let stem = filename.get(..stem_len)?;
    let extension = filename.get(stem_len..)?;

    if stem.is_empty() || !extension.eq_ignore_ascii_case(EXTENSION) {
        return None;
    }
    Some(stem.to_string())
}
55
56impl CSharpProjectFinder {
57    /// Extract version from .csproj XML content using quick-xml
58    fn extract_version(content: &str) -> Option<String> {
59        let mut reader = Reader::from_str(content);
60        let mut buf = Vec::new();
61        let mut in_property_group = false;
62        let mut in_version = false;
63
64        loop {
65            match reader.read_event_into(&mut buf) {
66                Ok(Event::Start(e)) => {
67                    let name = e.local_name();
68                    if name.as_ref() == b"PropertyGroup" {
69                        in_property_group = true;
70                    } else if in_property_group && name.as_ref() == b"Version" {
71                        in_version = true;
72                    }
73                }
74                Ok(Event::End(e)) => {
75                    let name = e.local_name();
76                    if name.as_ref() == b"PropertyGroup" {
77                        in_property_group = false;
78                    } else if name.as_ref() == b"Version" {
79                        in_version = false;
80                    }
81                }
82                Ok(Event::Text(e)) => {
83                    if in_version && let Ok(text) = e.decode() {
84                        let version = text.trim().to_string();
85                        if !version.is_empty() {
86                            return Some(version);
87                        }
88                    }
89                }
90                Ok(Event::Eof) | Err(_) => break,
91                _ => {}
92            }
93            buf.clear();
94        }
95        None
96    }
97
98    /// Extract `PackageReference` dependencies from .csproj XML content using quick-xml
99    #[allow(dead_code)]
100    fn extract_package_references(content: &str) -> Vec<String> {
101        let mut reader = Reader::from_str(content);
102        let mut buf = Vec::new();
103        let mut packages = Vec::new();
104
105        loop {
106            match reader.read_event_into(&mut buf) {
107                Ok(Event::Empty(e) | Event::Start(e)) => {
108                    if e.local_name().as_ref() == b"PackageReference" {
109                        for attr in e.attributes().flatten() {
110                            if attr.key.as_ref() == b"Include"
111                                && let Ok(value) = attr.unescape_value()
112                            {
113                                packages.push(value.to_string());
114                            }
115                        }
116                    }
117                }
118                Ok(Event::Eof) | Err(_) => break,
119                _ => {}
120            }
121            buf.clear();
122        }
123        packages
124    }
125
126    /// Extract `ProjectReference` dependencies from .csproj XML content using quick-xml
127    /// Returns the project names (extracted from paths)
128    fn extract_project_references(content: &str) -> Vec<String> {
129        let mut reader = Reader::from_str(content);
130        let mut buf = Vec::new();
131        let mut projects = Vec::new();
132
133        loop {
134            match reader.read_event_into(&mut buf) {
135                Ok(Event::Empty(e) | Event::Start(e)) => {
136                    if e.local_name().as_ref() == b"ProjectReference" {
137                        for attr in e.attributes().flatten() {
138                            if attr.key.as_ref() == b"Include"
139                                && let Ok(value) = attr.unescape_value()
140                            {
141                                // Extract project name from path like "..\CoreLib\CoreLib.csproj"
142                                // Handle both Windows (\) and Unix (/) path separators
143                                if let Some(name) = extract_project_name_from_path(&value) {
144                                    projects.push(name);
145                                }
146                            }
147                        }
148                    }
149                }
150                Ok(Event::Eof) | Err(_) => break,
151                _ => {}
152            }
153            buf.clear();
154        }
155        projects
156    }
157
158    /// Check if this project is part of a solution (workspace)
159    /// A project is considered a workspace if there's a .sln file in the same directory
160    async fn is_workspace(path: &Path) -> bool {
161        if let Some(parent) = path.parent() {
162            // Check if there's a .sln file in the parent directory
163            if let Ok(mut entries) = tokio::fs::read_dir(parent).await {
164                while let Ok(Some(entry)) = entries.next_entry().await {
165                    if let Some(ext) = entry.path().extension()
166                        && ext == "sln"
167                    {
168                        return true;
169                    }
170                }
171            }
172        }
173        false
174    }
175}
176
177#[async_trait]
178impl ProjectFinder for CSharpProjectFinder {
179    fn projects(&self) -> Vec<&Project> {
180        self.projects.values().collect::<Vec<_>>()
181    }
182
183    fn projects_mut(&mut self) -> Vec<&mut Project> {
184        self.projects.values_mut().collect::<Vec<_>>()
185    }
186
187    fn project_files(&self) -> &[&str] {
188        &self.project_files
189    }
190
191    async fn visit(&mut self, path: &Path, relative_path: &Path) -> Result<()> {
192        // Check if this is a .csproj file
193        if path.is_file() {
194            let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
195
196            if extension != "csproj" {
197                return Ok(());
198            }
199
200            if self.projects.contains_key(path) {
201                return Ok(());
202            }
203
204            // Read .csproj content
205            let csproj_content = read_to_string(path).await?;
206
207            let name = Self::extract_name_from_path(path);
208            let version = Self::extract_version(&csproj_content);
209            let is_workspace = Self::is_workspace(path).await;
210
211            let (path_key, mut project) = if is_workspace {
212                (
213                    path.to_path_buf(),
214                    Project::Workspace(Box::new(CSharpWorkspace::new(
215                        name,
216                        version,
217                        path.to_path_buf(),
218                        relative_path.to_path_buf(),
219                    ))),
220                )
221            } else {
222                (
223                    path.to_path_buf(),
224                    Project::Package(Box::new(CSharpPackage::new(
225                        name,
226                        version,
227                        path.to_path_buf(),
228                        relative_path.to_path_buf(),
229                    ))),
230                )
231            };
232
233            // Add ProjectReference dependencies (local project references)
234            for dep in Self::extract_project_references(&csproj_content) {
235                project.add_dependency(&dep);
236            }
237
238            self.projects.insert(path_key, project);
239        }
240        Ok(())
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Minimal SDK-style project that declares `<Version>1.0.0</Version>`.
    const VERSION_1_0_0_CSPROJ: &str = r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>1.0.0</Version>
  </PropertyGroup>
</Project>
"#;

    /// Writes `contents` to `file_name` directly inside `dir` and returns the full path.
    fn write_in(dir: &TempDir, file_name: &str, contents: &str) -> PathBuf {
        let target = dir.path().join(file_name);
        fs::write(&target, contents).unwrap();
        target
    }

    /// Runs `visit` on `path` with the given relative path and requires it to succeed.
    async fn visit_ok(finder: &mut CSharpProjectFinder, path: &Path, relative: &str) {
        finder.visit(path, Path::new(relative)).await.unwrap();
    }

    #[tokio::test]
    async fn test_new() {
        let finder = CSharpProjectFinder::new();
        assert_eq!(finder.project_files(), &[".csproj"]);
        assert_eq!(finder.projects().len(), 0);
    }

    #[tokio::test]
    async fn test_default() {
        let finder = CSharpProjectFinder::default();
        assert_eq!(finder.project_files(), &[".csproj"]);
        assert_eq!(finder.projects().len(), 0);
    }

    #[tokio::test]
    async fn test_visit_package() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = write_in(&temp_dir, "TestProject.csproj", VERSION_1_0_0_CSPROJ);

        let mut finder = CSharpProjectFinder::new();
        visit_ok(&mut finder, &csproj_path, "TestProject.csproj").await;

        let found = finder.projects();
        assert_eq!(found.len(), 1);
        match found[0] {
            Project::Package(pkg) => {
                assert_eq!(pkg.name(), Some("TestProject"));
                assert_eq!(pkg.version(), Some("1.0.0"));
            }
            _ => panic!("Expected Package"),
        }

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    async fn test_visit_workspace_with_sln() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = write_in(&temp_dir, "TestProject.csproj", VERSION_1_0_0_CSPROJ);
        // A solution file next to the project turns it into a workspace.
        write_in(
            &temp_dir,
            "TestSolution.sln",
            "Microsoft Visual Studio Solution File",
        );

        let mut finder = CSharpProjectFinder::new();
        visit_ok(&mut finder, &csproj_path, "TestProject.csproj").await;

        let found = finder.projects();
        assert_eq!(found.len(), 1);
        match found[0] {
            Project::Workspace(ws) => {
                assert_eq!(ws.name(), Some("TestProject"));
                assert_eq!(ws.version(), Some("1.0.0"));
            }
            _ => panic!("Expected Workspace"),
        }

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    async fn test_visit_package_without_version() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = write_in(
            &temp_dir,
            "TestProject.csproj",
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <OutputType>Exe</OutputType>
  </PropertyGroup>
</Project>
"#,
        );

        let mut finder = CSharpProjectFinder::new();
        visit_ok(&mut finder, &csproj_path, "TestProject.csproj").await;

        let found = finder.projects();
        assert_eq!(found.len(), 1);
        match found[0] {
            Project::Package(pkg) => {
                assert_eq!(pkg.name(), Some("TestProject"));
                assert_eq!(pkg.version(), None);
            }
            _ => panic!("Expected Package"),
        }

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    async fn test_visit_non_csproj_file() {
        let temp_dir = TempDir::new().unwrap();
        let other_file = write_in(&temp_dir, "other.xml", r#"<root>content</root>"#);

        let mut finder = CSharpProjectFinder::new();
        visit_ok(&mut finder, &other_file, "other.xml").await;

        assert_eq!(finder.projects().len(), 0);

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    async fn test_visit_directory() {
        let temp_dir = TempDir::new().unwrap();
        let dir_path = temp_dir.path().join("some_dir");
        fs::create_dir_all(&dir_path).unwrap();

        let mut finder = CSharpProjectFinder::new();
        visit_ok(&mut finder, &dir_path, "some_dir").await;

        assert_eq!(finder.projects().len(), 0);

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    async fn test_visit_duplicate() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = write_in(&temp_dir, "TestProject.csproj", VERSION_1_0_0_CSPROJ);

        let mut finder = CSharpProjectFinder::new();
        // Visiting the same file twice must register it only once.
        for _ in 0..2 {
            visit_ok(&mut finder, &csproj_path, "TestProject.csproj").await;
        }

        assert_eq!(finder.projects().len(), 1);

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    async fn test_visit_multiple_packages() {
        let temp_dir = TempDir::new().unwrap();
        let csproj1 = temp_dir.path().join("Project1").join("Project1.csproj");
        let csproj2 = temp_dir.path().join("Project2").join("Project2.csproj");
        fs::create_dir_all(csproj1.parent().unwrap()).unwrap();
        fs::create_dir_all(csproj2.parent().unwrap()).unwrap();
        fs::write(&csproj1, VERSION_1_0_0_CSPROJ).unwrap();
        fs::write(
            &csproj2,
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>2.0.0</Version>
  </PropertyGroup>
</Project>
"#,
        )
        .unwrap();

        let mut finder = CSharpProjectFinder::new();
        visit_ok(&mut finder, &csproj1, "Project1/Project1.csproj").await;
        visit_ok(&mut finder, &csproj2, "Project2/Project2.csproj").await;

        assert_eq!(finder.projects().len(), 2);

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    async fn test_projects_mut() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = write_in(&temp_dir, "TestProject.csproj", VERSION_1_0_0_CSPROJ);

        let mut finder = CSharpProjectFinder::new();
        visit_ok(&mut finder, &csproj_path, "TestProject.csproj").await;

        let mut projects = finder.projects_mut();
        assert_eq!(projects.len(), 1);
        match &mut projects[0] {
            Project::Package(pkg) => {
                assert!(!pkg.is_changed());
                pkg.set_changed(true);
                assert!(pkg.is_changed());
            }
            _ => panic!("Expected Package"),
        }

        temp_dir.close().unwrap();
    }

    #[tokio::test]
    async fn test_visit_package_with_project_references() {
        let temp_dir = TempDir::new().unwrap();
        let csproj_path = write_in(
            &temp_dir,
            "TestProject.csproj",
            r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>1.0.0</Version>
  </PropertyGroup>
  <ItemGroup>
    <PackageReference Include="Newtonsoft.Json" Version="13.0.1" />
  </ItemGroup>
  <ItemGroup>
    <ProjectReference Include="..\CoreLib\CoreLib.csproj" />
    <ProjectReference Include="..\Utils\Utils.csproj" />
  </ItemGroup>
</Project>
"#,
        );

        let mut finder = CSharpProjectFinder::new();
        visit_ok(&mut finder, &csproj_path, "TestProject.csproj").await;

        let projects = finder.projects();
        assert_eq!(projects.len(), 1);
        match projects[0] {
            Project::Package(pkg) => {
                assert_eq!(pkg.name(), Some("TestProject"));
                let deps = pkg.dependencies();
                // Only ProjectReferences are tracked (not PackageReferences)
                assert_eq!(deps.len(), 2);
                assert!(deps.contains("CoreLib"));
                assert!(deps.contains("Utils"));
            }
            _ => panic!("Expected Package"),
        }

        temp_dir.close().unwrap();
    }

    #[test]
    fn test_extract_version() {
        let content = r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <Version>1.2.3</Version>
  </PropertyGroup>
</Project>"#;
        assert_eq!(
            CSharpProjectFinder::extract_version(content),
            Some("1.2.3".to_string())
        );

        let no_version = r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <OutputType>Exe</OutputType>
  </PropertyGroup>
</Project>"#;
        assert_eq!(CSharpProjectFinder::extract_version(no_version), None);
    }

    #[test]
    fn test_extract_package_references() {
        let content = r#"<Project Sdk="Microsoft.NET.Sdk">
  <ItemGroup>
    <PackageReference Include="Newtonsoft.Json" Version="13.0.1" />
    <PackageReference Include="System.CommandLine" Version="2.0.0-beta4.22272.1" />
  </ItemGroup>
</Project>"#;
        let refs = CSharpProjectFinder::extract_package_references(content);
        assert_eq!(refs.len(), 2);
        for expected in ["Newtonsoft.Json", "System.CommandLine"] {
            assert!(refs.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_extract_project_references() {
        let content = r#"<Project Sdk="Microsoft.NET.Sdk">
  <ItemGroup>
    <ProjectReference Include="..\CoreLib\CoreLib.csproj" />
    <ProjectReference Include="..\Utils\Utils.csproj" />
  </ItemGroup>
</Project>"#;
        let refs = CSharpProjectFinder::extract_project_references(content);
        assert_eq!(refs.len(), 2);
        for expected in ["CoreLib", "Utils"] {
            assert!(refs.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_extract_project_name_from_path() {
        // Windows-style paths
        assert_eq!(
            super::extract_project_name_from_path(r"..\CoreLib\CoreLib.csproj"),
            Some("CoreLib".to_string())
        );
        assert_eq!(
            super::extract_project_name_from_path(r"..\..\Utils\Utils.csproj"),
            Some("Utils".to_string())
        );
        // Unix-style paths
        assert_eq!(
            super::extract_project_name_from_path("../CoreLib/CoreLib.csproj"),
            Some("CoreLib".to_string())
        );
        // Just filename
        assert_eq!(
            super::extract_project_name_from_path("MyProject.csproj"),
            Some("MyProject".to_string())
        );
        // Invalid - no .csproj extension
        assert_eq!(super::extract_project_name_from_path("MyProject.txt"), None);
    }

    #[test]
    fn test_extract_version_end_tag() {
        let content = r#"<Project><PropertyGroup><Version>
   1.2.3
   </Version></PropertyGroup></Project>"#;
        assert_eq!(
            CSharpProjectFinder::extract_version(content),
            Some("1.2.3".to_string())
        );
    }

    #[test]
    fn test_extract_version_malformed_xml() {
        let content = "<Project><PropertyGroup><Version>1.0.0";
        // Should not panic - either returns Some or None
        let _ = CSharpProjectFinder::extract_version(content);
    }

    #[test]
    fn test_extract_version_empty_version() {
        let content = r#"<Project><PropertyGroup><Version>  </Version></PropertyGroup></Project>"#;
        assert_eq!(CSharpProjectFinder::extract_version(content), None);
    }

    #[test]
    fn test_extract_version_with_empty_element() {
        // Self-closing tags like <IsPackable /> generate Event::Empty,
        // which exercises the wildcard `_ => {}` arm in extract_version
        let content = r#"<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <IsPackable />
    <Version>3.2.1</Version>
  </PropertyGroup>
</Project>"#;
        assert_eq!(
            CSharpProjectFinder::extract_version(content),
            Some("3.2.1".to_string())
        );
    }

    #[test]
    fn test_extract_version_with_comment() {
        // XML comments generate Event::Comment, exercising the wildcard arm
        let content = r#"<Project>
  <PropertyGroup>
    <!-- version follows -->
    <Version>4.0.0</Version>
  </PropertyGroup>
</Project>"#;
        assert_eq!(
            CSharpProjectFinder::extract_version(content),
            Some("4.0.0".to_string())
        );
    }
}