Skip to main content

lux_cli/project/
new.rs

1use std::{error::Error, fmt::Display, path::PathBuf, str::FromStr};
2
3use clap::Args;
4use eyre::{eyre, Result};
5use inquire::{
6    ui::{RenderConfig, Styled},
7    validator::Validation,
8    Confirm, Select, Text,
9};
10use itertools::Itertools;
11use spdx::LicenseId;
12use spinners::{Spinner, Spinners};
13
14use crate::utils::github_metadata::{self, RepoMetadata};
15use lux_lib::{
16    config::Config,
17    package::PackageReq,
18    project::{Project, PROJECT_TOML},
19};
20
21// TODO:
22// - Automatically detect build type to insert into rockspec by inspecting the current repo.
23//   E.g. if there is a `Cargo.toml` in the project root we can infer the user wants to use the
24//   Rust build backend.
25
/// The type of directory to create when making the project.
///
/// The selected variant decides the name of the source directory that receives
/// the generated `main.lua` (see the `Display` impl for the directory names).
#[derive(Debug, Clone, clap::ValueEnum)]
enum SourceDirType {
    // Rendered as `src` by the `Display` impl.
    Src,
    // Rendered as `lua` by the `Display` impl.
    Lua,
}
32
33impl Display for SourceDirType {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        match self {
36            Self::Src => write!(f, "src"),
37            Self::Lua => write!(f, "lua"),
38        }
39    }
40}
41
42#[derive(Args)]
43pub struct NewProject {
44    /// The directory of the project.
45    target: PathBuf,
46
47    /// The project's name.
48    #[arg(long)]
49    name: Option<String>,
50
51    /// The description of the project.
52    #[arg(long)]
53    description: Option<String>,
54
55    /// The license of the project. Generic license names will be inferred.
56    #[arg(long, value_parser = clap_parse_license)]
57    license: Option<LicenseId>,
58
59    /// The maintainer of this project. Does not have to be the code author.
60    #[arg(long)]
61    maintainer: Option<String>,
62
63    /// A comma-separated list of labels to apply to this project.
64    #[arg(long, value_parser = clap_parse_list)]
65    labels: Option<std::vec::Vec<String>>, // Note: full qualified name required, see https://github.com/clap-rs/clap/issues/4626
66
67    /// A version constraint on the required Lua version for this project.
68    /// Examples: ">=5.1", "5.1"
69    #[arg(long, value_parser = clap_parse_version)]
70    lua_versions: Option<PackageReq>,
71
72    #[arg(long)]
73    main: Option<SourceDirType>,
74}
75
/// Fully resolved project settings.
///
/// Built by `write_project_rockspec` once every value has been obtained either
/// from the command-line flags or from an interactive prompt, so (apart from the
/// license) nothing is optional any more.
struct NewProjectValidated {
    /// Directory the project files are written into.
    target: PathBuf,
    /// Package name written to the manifest.
    name: String,
    /// Summary written to the manifest's `[description]` table.
    description: String,
    /// Maintainer written to the manifest's `[description]` table.
    maintainer: String,
    /// Labels written to the manifest's `[description]` table.
    labels: Vec<String>,
    /// Lua requirement; only its version constraint is written to the manifest.
    lua_versions: PackageReq,
    /// Source directory that receives the generated `main.lua`.
    main: SourceDirType,
    /// License of the project; `None` means no `license` line is emitted.
    license: Option<LicenseId>,
}
86
87fn clap_parse_license(s: &str) -> std::result::Result<LicenseId, String> {
88    match validate_license(s) {
89        Ok(Validation::Valid) => unsafe { Ok(parse_license_unchecked(s)) },
90        Err(_) | Ok(Validation::Invalid(_)) => {
91            Err(format!("unable to identify license {s}, please try again!"))
92        }
93    }
94}
95
96fn clap_parse_version(input: &str) -> std::result::Result<PackageReq, String> {
97    PackageReq::from_str(format!("lua {input}").as_str()).map_err(|err| err.to_string())
98}
99
/// clap value parser for comma-separated lists such as `--labels web,filesystem`.
///
/// Every entry is trimmed of surrounding whitespace. Empty entries (produced by
/// an empty input, doubled commas or a trailing comma) are dropped so that no
/// empty names end up in the result.
///
/// # Errors
///
/// Returns a message naming the offending character and its zero-based
/// position when `input` contains ASCII punctuation other than `-`, `_` or `,`.
fn clap_parse_list(input: &str) -> std::result::Result<Vec<String>, String> {
    let offending = input
        .chars()
        .enumerate()
        .find(|&(_, c)| !matches!(c, '-' | '_' | ',') && c.is_ascii_punctuation());

    if let Some((pos, found)) = offending {
        return Err(format!(
            r#"Unexpected punctuation '{found}' found at column {pos}.
    Lists are comma separated but names should not contain punctuation!"#
        ));
    }

    Ok(input
        .split(',')
        .map(str::trim)
        .filter(|name| !name.is_empty())
        .map(String::from)
        .collect())
}
113
/// Parses a license without checking that the lookup succeeded.
///
/// # Safety
///
/// The caller must guarantee that `spdx::imprecise_license_id(input)` returns
/// `Some`. The result is unwrapped with `unwrap_unchecked`, so a failed lookup
/// is undefined behaviour rather than a panic.
///
/// WARNING: This should only be invoked after validating the license with [`validate_license`].
/// Note that [`validate_license`] also accepts the literal `"none"` without
/// performing a lookup, so that value must be handled by the caller and never
/// passed to this function.
unsafe fn parse_license_unchecked(input: &str) -> LicenseId {
    // The first tuple element of the lookup result is the license id.
    spdx::imprecise_license_id(input).unwrap_unchecked().0
}
122
123fn validate_license(input: &str) -> std::result::Result<Validation, Box<dyn Error + Send + Sync>> {
124    if input == "none" {
125        return Ok(Validation::Valid);
126    }
127
128    Ok(
129        match spdx::imprecise_license_id(input).ok_or(format!(
130            "Unable to identify license '{input}', please try again!",
131        )) {
132            Ok(_) => Validation::Valid,
133            Err(err) => Validation::Invalid(err.into()),
134        },
135    )
136}
137
138pub async fn write_project_rockspec(cli_flags: NewProject, config: Config) -> Result<()> {
139    let project = Project::from_exact(cli_flags.target.clone())?;
140    let render_config = RenderConfig::default_colored()
141        .with_prompt_prefix(Styled::new(">").with_fg(inquire::ui::Color::LightGreen));
142
143    // If the project already exists then ask for override confirmation
144    if project.is_some() && config.no_prompt()
145        || !Confirm::new("Target directory already has a project, write anyway?")
146            .with_default(false)
147            .with_help_message(&format!("This may overwrite your existing {PROJECT_TOML}",))
148            .with_render_config(render_config)
149            .prompt()?
150    {
151        return Err(eyre!("cancelled creation of project (already exists)"));
152    };
153
154    let validated = match cli_flags {
155        // If all parameters are provided then don't bother prompting the user
156        NewProject {
157            description: Some(description),
158            main: Some(main),
159            labels: Some(labels),
160            lua_versions: Some(lua_versions),
161            maintainer: Some(maintainer),
162            name: Some(name),
163            license,
164            target,
165        } => Ok::<_, eyre::Report>(NewProjectValidated {
166            description,
167            labels,
168            license,
169            lua_versions,
170            main,
171            maintainer,
172            name,
173            target,
174        }),
175
176        NewProject {
177            description,
178            labels,
179            license,
180            lua_versions,
181            main,
182            maintainer,
183            name,
184            target,
185        } => {
186            let mut spinner = Spinner::new(
187                Spinners::Dots,
188                "Fetching remote repository metadata... ".into(),
189            );
190
191            let repo_metadata = match github_metadata::get_metadata_for(Some(&target)).await {
192                Ok(value) => value.map_or_else(|| RepoMetadata::default(&target), Ok),
193                Err(_) => {
194                    println!("Could not fetch remote repo metadata, defaulting to empty values.");
195
196                    RepoMetadata::default(&target)
197                }
198            }?;
199
200            spinner.stop_and_persist("✔", "Fetched remote repository metadata.".into());
201
202            let package_name = name.map_or_else(
203                || {
204                    Text::new("Package name:")
205                        .with_default(&repo_metadata.name)
206                        .with_help_message("A folder with the same name will be created for you.")
207                        .with_render_config(render_config)
208                        .prompt()
209                },
210                Ok,
211            )?;
212
213            let description = description.map_or_else(
214                || {
215                    Text::new("Description:")
216                        .with_default(&repo_metadata.description.unwrap_or_default())
217                        .with_render_config(render_config)
218                        .prompt()
219                },
220                Ok,
221            )?;
222
223            let license = license.map_or_else(
224                || {
225                    Ok::<_, eyre::Error>(
226                        match Text::new("License:")
227                            .with_default(&repo_metadata.license.unwrap_or("none".into()))
228                            .with_help_message("Type 'none' for no license")
229                            .with_validator(validate_license)
230                            .with_render_config(render_config)
231                            .prompt()?
232                            .as_str()
233                        {
234                            "none" => None,
235                            license => unsafe { Some(parse_license_unchecked(license)) },
236                        },
237                    )
238                },
239                |license| Ok(Some(license)),
240            )?;
241
242            let labels = labels.or(repo_metadata.labels).map_or_else(
243                || {
244                    Ok::<_, eyre::Error>(
245                        Text::new("Labels:")
246                            .with_placeholder("web,filesystem")
247                            .with_help_message("Labels are comma separated")
248                            .prompt()?
249                            .split(',')
250                            .map(|label| label.trim().to_string())
251                            .collect_vec(),
252                    )
253                },
254                Ok,
255            )?;
256
257            let maintainer = maintainer.map_or_else(
258                || {
259                    let prompt = Text::new("Maintainer:");
260                    if let Some(default_maintainer) = repo_metadata
261                        .contributors
262                        .first()
263                        .cloned()
264                        .or_else(|| whoami::realname().ok())
265                    {
266                        prompt.with_default(&default_maintainer).prompt()
267                    } else {
268                        prompt.prompt()
269                    }
270                },
271                Ok,
272            )?;
273
274            let lua_versions = lua_versions.map_or_else(
275                || {
276                    Ok::<_, eyre::Report>(
277                        format!(
278                            "lua >= {}",
279                            Select::new(
280                                "What is the lowest Lua version you support?",
281                                vec!["5.1", "5.2", "5.3", "5.4", "5.5"]
282                            )
283                            .without_filtering()
284                            .with_vim_mode(true)
285                            .with_help_message(
286                                "This is equivalent to the 'lua >= {version}' constraint."
287                            )
288                            .prompt()?
289                        )
290                        .parse()?,
291                    )
292                },
293                Ok,
294            )?;
295
296            Ok(NewProjectValidated {
297                target,
298                name: package_name,
299                description,
300                labels,
301                license,
302                lua_versions,
303                maintainer,
304                main: main.unwrap_or(SourceDirType::Src),
305            })
306        }
307    }?;
308
309    let _ = std::fs::create_dir_all(&validated.target);
310
311    let rocks_path = validated.target.join(PROJECT_TOML);
312
313    std::fs::write(
314        &rocks_path,
315        format!(
316            r#"
317package = "{package_name}"
318version = "0.1.0"
319lua = "{lua_version_req}"
320
321[description]
322summary = "{summary}"
323maintainer = "{maintainer}"
324labels = [ {labels} ]
325{license}
326
327[dependencies]
328# Add your dependencies here
329# `busted = ">=2.0"`
330
331[run]
332args = [ "{main}/main.lua" ]
333
334[build]
335type = "builtin"
336    "#,
337            package_name = validated.name,
338            summary = validated.description,
339            license = validated
340                .license
341                .map(|license| format!(r#"license = "{}""#, license.name))
342                .unwrap_or_default(),
343            maintainer = validated.maintainer,
344            labels = validated
345                .labels
346                .into_iter()
347                .map(|label| "\"".to_string() + &label + "\"")
348                .join(", "),
349            lua_version_req = validated.lua_versions.version_req(),
350            main = validated.main,
351        )
352        .trim(),
353    )?;
354
355    let main_dir = validated.target.join(validated.main.to_string());
356    if main_dir.exists() {
357        eprintln!(
358            "Directory `{}/` already exists - we won't make any changes to it.",
359            main_dir.display()
360        );
361    } else {
362        std::fs::create_dir(&main_dir)?;
363        std::fs::write(main_dir.join("main.lua"), r#"print("Hello world!")"#)?;
364    }
365
366    println!("All done!");
367
368    Ok(())
369}
370
371// TODO(vhyrro): Add tests