use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tower_lsp::lsp_types::Position;
/// One `import` statement of a proto file.
#[derive(Debug, Clone)]
pub struct ImportElement {
/// Imported file name, copied from the parser's `filename` field.
pub path: String,
/// Line of the import: the parser-reported line minus one, clamped at zero (see `pos_line`).
pub line: u32,
/// Column of the import: the parser-reported column minus one, clamped at zero (see `pos_col`).
pub character: u32,
}
/// A problem reported while parsing a proto file.
#[derive(Debug, Clone)]
pub struct ParseError {
/// Message text, copied from the underlying parser error.
pub message: String,
/// Line of the error: the parser-reported line minus one, clamped at zero.
pub line: u32,
/// Column of the error: the parser-reported column minus one, clamped at zero.
pub character: u32,
/// How serious the problem is; `ProtoParser::parse` always records `ErrorSeverity::Error`.
pub severity: ErrorSeverity,
}
/// Severity attached to a [`ParseError`].
#[derive(Debug, Clone)]
// Only `Error` is constructed in this file; the other variants are kept for callers.
#[allow(dead_code)]
pub enum ErrorSeverity {
/// The file could not be parsed.
Error,
/// Not produced by this file.
Warning,
/// Not produced by this file.
Info,
}
/// The converted contents of one proto file, or the error that prevented parsing it.
#[derive(Debug, Clone)]
pub struct ParsedProto {
/// Identifier the caller passed to `parse` for this document.
pub uri: String,
/// Name from the `package` statement, if one was seen.
pub package: Option<String>,
/// Every `import` statement, in source order.
pub imports: Vec<ImportElement>,
/// Top-level messages; nested ones live inside their parent `MessageElement`.
pub messages: Vec<MessageElement>,
/// Top-level enums; nested ones live inside their parent `MessageElement`.
pub enums: Vec<EnumElement>,
/// Service declarations.
pub services: Vec<ServiceElement>,
/// Top-level `extend` blocks.
pub extends: Vec<ExtendElement>,
/// Maps the first line of each top-level message, enum and service to a copy of that element.
/// Extend blocks, fields, methods and nested declarations are not indexed.
pub line_to_element: HashMap<u32, ProtoElement>,
/// Empty on success; holds a single entry when the underlying parser failed.
pub parse_errors: Vec<ParseError>,
}
/// A `message` declaration together with everything declared inside it.
#[derive(Debug, Clone)]
pub struct MessageElement {
/// Name as written in the declaration.
pub name: String,
/// Name prefixed with the package and enclosing messages (built by `make_full_name`).
pub full_name: String,
/// Normal fields, map fields and the members of every `oneof`, flattened in source order.
pub fields: Vec<FieldElement>,
/// Messages declared directly inside this one (nested `extend` blocks are skipped).
pub nested_messages: Vec<MessageElement>,
/// Enums declared directly inside this one.
pub nested_enums: Vec<EnumElement>,
/// Line of the declaration (parser line minus one).
pub line: u32,
/// One more than the highest line seen among this message's fields and nested declarations;
/// presumably an approximation of the closing-brace line.
pub end_line: u32,
/// Column of the name, computed as the declaration column plus the keyword length plus one,
/// i.e. assuming a single separator between `message` and the name.
pub character: u32,
}
/// A field of a message, `oneof` or `extend` block.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct FieldElement {
/// Field name.
pub name: String,
/// Type as written; map fields are rendered as `map<K, V>`.
pub field_type: String,
/// The type name when it is not a built-in scalar (see `is_builtin_type`);
/// `None` for scalars and for map fields.
pub type_name: Option<String>,
/// Field number, taken from the parser's `sequence` value.
pub number: i32,
/// Explicit label, if any; map fields are always recorded as `Repeated`,
/// `oneof` members always as `None`.
pub label: Option<FieldLabelProto>,
/// Line of the field (parser line minus one).
pub line: u32,
/// Column of the field (parser column minus one).
pub character: u32,
}
/// An `enum` declaration and its values.
#[derive(Debug, Clone)]
pub struct EnumElement {
/// Name as written in the declaration.
pub name: String,
/// Name prefixed with the package and enclosing messages (built by `make_full_name`).
pub full_name: String,
/// Declared values, in source order.
pub values: Vec<EnumValueElement>,
/// Line of the declaration (parser line minus one).
pub line: u32,
/// One more than the highest line seen among the values;
/// presumably an approximation of the closing-brace line.
pub end_line: u32,
/// Column of the name, assuming a single separator between `enum` and the name.
pub character: u32,
}
/// One value of an enum.
#[derive(Debug, Clone)]
pub struct EnumValueElement {
/// Value name.
pub name: String,
/// Numeric value, taken from the parser's `integer` field.
pub number: i32,
/// Line of the value (parser line minus one).
pub line: u32,
/// Column of the value (parser column minus one).
pub character: u32,
}
/// An `extend` block and the fields it adds to another message.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct ExtendElement {
    /// Name of the extended message, as written after `extend`.
    pub name: String,
    /// `name` prefixed with the file's package, when there is one.
    pub full_name: String,
    /// Normal fields declared inside the block.
    pub fields: Vec<FieldElement>,
    /// Line of the declaration (parser line minus one).
    pub line: u32,
    /// One more than the highest line seen among the fields;
    /// presumably an approximation of the closing-brace line.
    pub end_line: u32,
    /// Column of the name, assuming a single separator between `extend` and the name.
    pub character: u32,
}
/// A `service` declaration and its RPC methods.
#[derive(Debug, Clone)]
pub struct ServiceElement {
/// Name as written in the declaration.
pub name: String,
/// `name` prefixed with the file's package, when there is one.
pub full_name: String,
/// RPC methods, in source order.
pub methods: Vec<MethodElement>,
/// Line of the declaration (parser line minus one).
pub line: u32,
/// One more than the highest line seen among the methods;
/// presumably an approximation of the closing-brace line.
pub end_line: u32,
/// Column of the name, assuming a single separator between `service` and the name.
pub character: u32,
}
/// One `rpc` method of a service.
#[derive(Debug, Clone)]
pub struct MethodElement {
/// Method name.
pub name: String,
/// Request type with a leading dot, as produced by `qualify_type_name`.
pub input_type: String,
/// Response type with a leading dot, as produced by `qualify_type_name`.
pub output_type: String,
/// Whether the request is declared as `stream`.
pub client_streaming: bool,
/// Whether the response is declared as `stream`.
pub server_streaming: bool,
/// Line of the declaration (parser line minus one).
pub line: u32,
/// Column of the name, assuming a single separator between `rpc` and the name.
pub character: u32,
}
/// Label written in front of a field.
#[derive(Debug, Clone)]
pub enum FieldLabelProto {
/// `optional`
Optional,
/// `required`
Required,
/// `repeated`; also recorded for every map field.
Repeated,
}
/// Any element that can be stored in `ParsedProto::line_to_element`.
#[derive(Debug, Clone)]
// `Field` and `Method` are never constructed in this file.
#[allow(dead_code)]
pub enum ProtoElement {
/// A top-level message.
Message(MessageElement),
/// A top-level enum.
Enum(EnumElement),
/// A service.
Service(ServiceElement),
/// A field (not produced by `convert_proto`).
Field(FieldElement),
/// An RPC method (not produced by `convert_proto`).
Method(MethodElement),
}
/// Parses proto sources into [`ParsedProto`] values and remembers the results.
pub struct ProtoParser {
// Results of earlier `parse` calls, shared behind an async read/write lock.
cache: Arc<RwLock<HashMap<String, ParsedProto>>>,
}
impl ProtoParser {
/// Creates a parser whose result cache starts out empty.
pub fn new() -> Self {
    let entries = HashMap::new();
    Self {
        cache: Arc::new(RwLock::new(entries)),
    }
}
/// Parses `content` as the proto document identified by `uri`.
///
/// The call always returns `Ok`: when the underlying parser rejects the input,
/// the returned [`ParsedProto`] is empty apart from a single entry in
/// `parse_errors` carrying the parser's message and position.
///
/// Results are cached per document *revision*. The cache key combines the URI
/// with the length and a 64-bit hash of `content`, so re-parsing an edited
/// document never returns the outcome of an older text. Only the most recent
/// revision of each URI is kept.
pub async fn parse(&self, uri: String, content: &str) -> Result<ParsedProto> {
    use std::hash::{Hash, Hasher};

    // Fingerprint the text. A hash collision between two revisions of the same
    // document would still serve a stale entry, but including the length makes
    // that even less likely than the bare 64-bit hash already does.
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    content.hash(&mut hasher);
    // NUL cannot occur in a URI, so it cleanly separates the URI from the fingerprint.
    let uri_prefix = format!("{}\u{0}", uri);
    let cache_key = format!("{}{:x}:{:016x}", uri_prefix, content.len(), hasher.finish());

    {
        let cache = self.cache.read().await;
        if let Some(cached) = cache.get(&cache_key) {
            return Ok(cached.clone());
        }
    }

    let parse_result = match proto_parser::Parser::new(content).parse() {
        Ok(proto) => self.convert_proto(&uri, &proto),
        Err(e) => {
            // Shift the parser position down by one, never going below zero.
            let line = (e.position.line as u32).saturating_sub(1);
            let character = (e.position.column as u32).saturating_sub(1);
            ParsedProto {
                uri,
                package: None,
                imports: Vec::new(),
                messages: Vec::new(),
                enums: Vec::new(),
                services: Vec::new(),
                extends: Vec::new(),
                line_to_element: HashMap::new(),
                parse_errors: vec![ParseError {
                    message: e.message.clone(),
                    line,
                    character,
                    severity: ErrorSeverity::Error,
                }],
            }
        }
    };

    {
        let mut cache = self.cache.write().await;
        // Drop older revisions of this document so the cache does not grow with every edit.
        cache.retain(|key, _| !key.starts_with(&uri_prefix));
        cache.insert(cache_key, parse_result.clone());
    }
    Ok(parse_result)
}
/// Converts the parser's syntax tree into a [`ParsedProto`] for `uri`.
///
/// Top-level elements are visited in source order. Messages, enums and services
/// are also indexed by their first line in `line_to_element`; `extend` blocks
/// are collected separately and not indexed. The package in effect for an
/// element is whatever `package` statement was seen before it.
fn convert_proto(&self, uri: &str, proto: &proto_parser::Proto) -> ParsedProto {
    let mut parsed = ParsedProto {
        uri: uri.to_string(),
        package: None,
        imports: Vec::new(),
        messages: Vec::new(),
        enums: Vec::new(),
        services: Vec::new(),
        extends: Vec::new(),
        line_to_element: HashMap::new(),
        parse_errors: Vec::new(),
    };

    for element in &proto.elements {
        match element {
            proto_parser::Element::Package(decl) => {
                parsed.package = Some(decl.name.clone());
            }
            proto_parser::Element::Import(decl) => {
                parsed.imports.push(ImportElement {
                    path: decl.filename.clone(),
                    line: pos_line(decl.position.line),
                    character: pos_col(decl.position.column),
                });
            }
            // The parser reports `extend` blocks as messages flagged `is_extend`.
            proto_parser::Element::Message(decl) if decl.is_extend => {
                let extend = self.convert_extend(decl, &parsed.package);
                parsed.extends.push(extend);
            }
            proto_parser::Element::Message(decl) => {
                let message = self.convert_message(decl, &parsed.package, "");
                parsed
                    .line_to_element
                    .insert(message.line, ProtoElement::Message(message.clone()));
                parsed.messages.push(message);
            }
            proto_parser::Element::Enum(decl) => {
                let converted = self.convert_enum(decl, &parsed.package, "");
                parsed
                    .line_to_element
                    .insert(converted.line, ProtoElement::Enum(converted.clone()));
                parsed.enums.push(converted);
            }
            proto_parser::Element::Service(decl) => {
                let converted = self.convert_service(decl, &parsed.package);
                parsed
                    .line_to_element
                    .insert(converted.line, ProtoElement::Service(converted.clone()));
                parsed.services.push(converted);
            }
            // Every other top-level element is ignored.
            _ => {}
        }
    }

    parsed
}
fn convert_extend(
&self,
m: &proto_parser::Message,
package: &Option<String>,
) -> ExtendElement {
let name = m.name.clone();
let full_name = if let Some(pkg) = package {
format!("{}.{}", pkg, name)
} else {
name.clone()
};
let mut fields = Vec::new();
let mut last_line = pos_line(m.position.line);
for elem in &m.elements {
match elem {
proto_parser::Element::NormalField(f) => {
let fe = self.convert_normal_field(f);
if fe.line > last_line {
last_line = fe.line;
}
fields.push(fe);
}
_ => {}
}
}
let end_line = if last_line > pos_line(m.position.line) {
last_line + 1
} else {
pos_line(m.position.line) + 1
};
let name_column = m.position.column + "extend".len() + 1;
ExtendElement {
name,
full_name,
fields,
line: pos_line(m.position.line),
end_line,
character: pos_col(name_column),
}
}
/// Converts a message declaration, recursing into nested messages and enums.
///
/// `parent_name` is empty for a top-level message; for a nested one the caller
/// passes the enclosing message's `full_name`. Normal fields, map fields and
/// `oneof` members all end up in `fields`, in source order. Nested `extend`
/// blocks and any other element kinds are ignored.
fn convert_message(
    &self,
    m: &proto_parser::Message,
    package: &Option<String>,
    parent_name: &str,
) -> MessageElement {
    let start_line = pos_line(m.position.line);
    let name = m.name.clone();
    let full_name = make_full_name(package, parent_name, &name);

    let mut fields = Vec::new();
    let mut nested_messages = Vec::new();
    let mut nested_enums = Vec::new();
    // Highest line seen so far among the message's contents.
    let mut max_line = start_line;

    for elem in &m.elements {
        match elem {
            proto_parser::Element::NormalField(f) => {
                let field = self.convert_normal_field(f);
                max_line = max_line.max(field.line);
                fields.push(field);
            }
            proto_parser::Element::MapField(f) => {
                let field = self.convert_map_field(f);
                max_line = max_line.max(field.line);
                fields.push(field);
            }
            proto_parser::Element::Oneof(group) => {
                // Members of a oneof are flattened into the message's field list.
                for member in &group.elements {
                    if let proto_parser::Element::OneofField(of) = member {
                        let field = self.convert_oneof_field(of);
                        max_line = max_line.max(field.line);
                        fields.push(field);
                    }
                }
            }
            proto_parser::Element::Message(inner) if !inner.is_extend => {
                let nested = self.convert_message(inner, package, &full_name);
                max_line = max_line.max(nested.end_line);
                nested_messages.push(nested);
            }
            proto_parser::Element::Enum(inner) => {
                let nested = self.convert_enum(inner, package, &full_name);
                max_line = max_line.max(nested.end_line);
                nested_enums.push(nested);
            }
            _ => {}
        }
    }

    MessageElement {
        name,
        full_name,
        fields,
        nested_messages,
        nested_enums,
        line: start_line,
        end_line: max_line + 1,
        // Skip the keyword and one separator to land on the name.
        character: pos_col(m.position.column + "message".len() + 1),
    }
}
/// Converts an ordinary (non-map, non-oneof) field.
///
/// When several label flags are set, `repeated` wins over `optional`, which
/// wins over `required`. `type_name` is filled in only for non-scalar types.
fn convert_normal_field(&self, f: &proto_parser::NormalField) -> FieldElement {
    let label = match (f.repeated, f.optional, f.required) {
        (true, _, _) => Some(FieldLabelProto::Repeated),
        (false, true, _) => Some(FieldLabelProto::Optional),
        (false, false, true) => Some(FieldLabelProto::Required),
        (false, false, false) => None,
    };

    let declared_type = f.field.type_name.clone();
    let type_name = if is_builtin_type(&declared_type) {
        None
    } else {
        Some(declared_type.clone())
    };

    FieldElement {
        name: f.field.name.clone(),
        field_type: declared_type,
        type_name,
        number: f.field.sequence as i32,
        label,
        line: pos_line(f.field.position.line),
        character: pos_col(f.field.position.column),
    }
}
/// Converts a `map<K, V>` field.
///
/// The type is rendered as `map<K, V>`, the label is always `Repeated` and
/// `type_name` is always `None`.
fn convert_map_field(&self, f: &proto_parser::MapField) -> FieldElement {
    let inner = &f.field;
    FieldElement {
        name: inner.name.clone(),
        field_type: format!("map<{}, {}>", f.key_type, inner.type_name),
        type_name: None,
        number: inner.sequence as i32,
        label: Some(FieldLabelProto::Repeated),
        line: pos_line(inner.position.line),
        character: pos_col(inner.position.column),
    }
}
/// Converts a member of a `oneof`; such fields never carry a label.
fn convert_oneof_field(&self, f: &proto_parser::OneofField) -> FieldElement {
    let inner = &f.field;
    // Only non-scalar types are recorded as a referenced type name.
    let type_name = match is_builtin_type(&inner.type_name) {
        true => None,
        false => Some(inner.type_name.clone()),
    };
    FieldElement {
        name: inner.name.clone(),
        field_type: inner.type_name.clone(),
        type_name,
        number: inner.sequence as i32,
        label: None,
        line: pos_line(inner.position.line),
        character: pos_col(inner.position.column),
    }
}
fn convert_enum(
&self,
e: &proto_parser::Enum,
package: &Option<String>,
parent_name: &str,
) -> EnumElement {
let name = e.name.clone();
let full_name = make_full_name(package, parent_name, &name);
let mut values = Vec::new();
let mut last_line = pos_line(e.position.line);
for elem in &e.elements {
if let proto_parser::Element::EnumField(ef) = elem {
let line = pos_line(ef.position.line);
if line > last_line {
last_line = line;
}
values.push(EnumValueElement {
name: ef.name.clone(),
number: ef.integer as i32,
line,
character: pos_col(ef.position.column),
});
}
}
let end_line = if last_line > pos_line(e.position.line) {
last_line + 1
} else {
pos_line(e.position.line) + 1
};
let name_column = e.position.column + "enum".len() + 1;
EnumElement {
name,
full_name,
values,
line: pos_line(e.position.line),
end_line,
character: pos_col(name_column),
}
}
fn convert_service(
&self,
s: &proto_parser::Service,
package: &Option<String>,
) -> ServiceElement {
let name = s.name.clone();
let full_name = if let Some(pkg) = package {
format!("{}.{}", pkg, name)
} else {
name.clone()
};
let mut methods = Vec::new();
let mut last_line = pos_line(s.position.line);
for elem in &s.elements {
if let proto_parser::Element::Rpc(rpc) = elem {
let line = pos_line(rpc.position.line);
if line > last_line {
last_line = line;
}
let input_type = qualify_type_name(&rpc.request_type, package);
let output_type = qualify_type_name(&rpc.returns_type, package);
let method_name_column = rpc.position.column + "rpc".len() + 1;
methods.push(MethodElement {
name: rpc.name.clone(),
input_type,
output_type,
client_streaming: rpc.streams_request,
server_streaming: rpc.streams_returns,
line,
character: pos_col(method_name_column),
});
}
}
let end_line = if last_line > pos_line(s.position.line) {
last_line + 1
} else {
pos_line(s.position.line) + 1
};
let name_column = s.position.column + "service".len() + 1;
ServiceElement {
name,
full_name,
methods,
line: pos_line(s.position.line),
end_line,
character: pos_col(name_column),
}
}
/// Forgets every cached parse result.
#[allow(dead_code)]
pub async fn clear_cache(&self) {
    self.cache.write().await.clear();
}
}
impl Default for ProtoParser {
fn default() -> Self {
Self::new()
}
}
/// Converts a parser line number into the line stored on elements.
///
/// The result is `line - 1`, with `0` mapping to `0`. Values too large for a
/// `u32` are clamped to `u32::MAX` rather than being truncated, which could
/// previously wrap to a small number or underflow.
fn pos_line(line: usize) -> u32 {
    // Subtract before narrowing so the narrowing step can only clamp.
    line.saturating_sub(1).min(u32::MAX as usize) as u32
}
/// Converts a parser column number into the column stored on elements.
///
/// The result is `col - 1`, with `0` mapping to `0`. Values too large for a
/// `u32` are clamped to `u32::MAX` rather than being truncated, which could
/// previously wrap to a small number or underflow.
fn pos_col(col: usize) -> u32 {
    // Subtract before narrowing so the narrowing step can only clamp.
    col.saturating_sub(1).min(u32::MAX as usize) as u32
}
/// Builds the dotted full name of a declaration.
///
/// * `package` - package of the file, if any.
/// * `parent_name` - empty for a top-level declaration, otherwise the name of
///   the enclosing message. It may already be package-qualified: the
///   converters in this file pass the parent's `full_name`.
/// * `name` - the declaration's own name.
///
/// The package is prepended only when `parent_name` does not already start
/// with `<package>.`; without this check a nested type came out as
/// `pkg.pkg.Outer.Inner`.
fn make_full_name(package: &Option<String>, parent_name: &str, name: &str) -> String {
    let pkg = match package {
        Some(pkg) => pkg.as_str(),
        None if parent_name.is_empty() => return name.to_string(),
        None => return format!("{}.{}", parent_name, name),
    };

    if parent_name.is_empty() {
        return format!("{}.{}", pkg, name);
    }

    // Require the dot after the prefix so that a parent such as `testing` is
    // not mistaken for something inside package `test`.
    let parent_is_qualified = parent_name
        .strip_prefix(pkg)
        .map_or(false, |rest| rest.starts_with('.'));

    if parent_is_qualified {
        format!("{}.{}", parent_name, name)
    } else {
        format!("{}.{}.{}", pkg, parent_name, name)
    }
}
/// Turns an RPC type reference into a name with a leading dot.
///
/// A name that already starts with `.` is returned unchanged. Any other name
/// is prefixed with `.<package>.`, or just `.` when the file has no package.
// NOTE(review): a dotted reference into another package (e.g. `other.Msg`)
// also receives the local package prefix — confirm whether callers resolve
// such names elsewhere.
fn qualify_type_name(type_name: &str, package: &Option<String>) -> String {
    match (type_name.starts_with('.'), package) {
        (true, _) => type_name.to_string(),
        (false, Some(pkg)) => format!(".{}.{}", pkg, type_name),
        (false, None) => format!(".{}", type_name),
    }
}
/// Returns `true` when `t` is exactly one of the fifteen scalar type keywords
/// listed below; the comparison is case-sensitive.
fn is_builtin_type(t: &str) -> bool {
    const SCALAR_TYPES: [&str; 15] = [
        "double", "float", "int32", "int64", "uint32", "uint64", "sint32", "sint64", "fixed32",
        "fixed64", "sfixed32", "sfixed64", "bool", "string", "bytes",
    ];
    SCALAR_TYPES.iter().any(|scalar| *scalar == t)
}
impl ParsedProto {
/// Parses `content` with a throw-away [`ProtoParser`].
///
/// Because a new parser is created on every call, nothing is reused from
/// earlier calls.
#[allow(dead_code)]
pub async fn parse(uri: String, content: &str) -> Result<Self> {
    let one_shot = ProtoParser::new();
    one_shot.parse(uri, content).await
}
pub fn find_element_at_position(&self, position: Position) -> Option<&ProtoElement> {
self.line_to_element.get(&position.line)
}
/// Finds a message, top-level or nested, whose short or full name equals `name`.
pub fn find_message_by_name(&self, name: &str) -> Option<&MessageElement> {
    let roots = self.messages.as_slice();
    self.find_message_recursive(roots, name)
}
/// Depth-first search through `messages` and their nested messages.
///
/// Each message is checked before its children; the first one whose `name`
/// or `full_name` equals `name` is returned.
fn find_message_recursive<'a>(
    &'a self,
    messages: &'a [MessageElement],
    name: &str,
) -> Option<&'a MessageElement> {
    messages.iter().find_map(|candidate| {
        if candidate.name == name || candidate.full_name == name {
            Some(candidate)
        } else {
            self.find_message_recursive(&candidate.nested_messages, name)
        }
    })
}
/// Finds an enum, top-level or nested in a message, whose short or full name equals `name`.
pub fn find_enum_by_name(&self, name: &str) -> Option<&EnumElement> {
    let top_level = self.enums.as_slice();
    self.find_enum_recursive(top_level, name)
}
/// Looks for `name` among `enums` first, then inside every top-level message
/// of this file (including their nested messages).
fn find_enum_recursive<'a>(
    &'a self,
    enums: &'a [EnumElement],
    name: &str,
) -> Option<&'a EnumElement> {
    enums
        .iter()
        .find(|candidate| candidate.name == name || candidate.full_name == name)
        .or_else(|| {
            self.messages
                .iter()
                .find_map(|message| self.find_enum_in_message(message, name))
        })
}
/// Searches the enums declared directly in `msg`, then recurses into its
/// nested messages in order.
fn find_enum_in_message<'a>(
    &'a self,
    msg: &'a MessageElement,
    name: &str,
) -> Option<&'a EnumElement> {
    msg.nested_enums
        .iter()
        .find(|candidate| candidate.name == name || candidate.full_name == name)
        .or_else(|| {
            msg.nested_messages
                .iter()
                .find_map(|inner| self.find_enum_in_message(inner, name))
        })
}
/// Returns the first service whose short or full name equals `name`.
pub fn find_service_by_name(&self, name: &str) -> Option<&ServiceElement> {
    for service in &self.services {
        if service.name == name || service.full_name == name {
            return Some(service);
        }
    }
    None
}
/// Returns the first field called `name` declared in any `extend` block,
/// together with the block that declares it.
pub fn find_extend_field_by_name(&self, name: &str) -> Option<(&ExtendElement, &FieldElement)> {
    self.extends.iter().find_map(|block| {
        block
            .fields
            .iter()
            .find(|field| field.name == name)
            .map(|field| (block, field))
    })
}
/// Returns the first RPC method called `name` in any service, together with
/// the service that declares it.
pub fn find_method_by_name(&self, name: &str) -> Option<(&ServiceElement, &MethodElement)> {
    self.services.iter().find_map(|owner| {
        owner
            .methods
            .iter()
            .find(|method| method.name == name)
            .map(|method| (owner, method))
    })
}
}
#[cfg(test)]
mod tests {
use super::*;
// A service with unary, server-streaming and client-streaming methods.
#[tokio::test]
async fn test_parse_with_service() {
    let content = r#"
syntax = "proto3";
package test;
service UserService {
rpc GetUser(GetUserRequest) returns (GetUserResponse);
rpc ListUsers(ListUsersRequest) returns (stream ListUsersResponse);
rpc UpdateUser(stream UpdateUserRequest) returns (UpdateUserResponse);
}
message GetUserRequest {
string user_id = 1;
}
message GetUserResponse {
User user = 1;
}
message User {
string id = 1;
string name = 2;
int32 age = 3;
}
"#;
    let outcome = ParsedProto::parse("test.proto".to_string(), content).await;
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    assert_eq!(parsed.package, Some("test".to_string()));
    assert_eq!(parsed.services.len(), 1);

    let user_service = &parsed.services[0];
    assert_eq!(user_service.name, "UserService");
    assert_eq!(user_service.methods.len(), 3);

    // Request and response types are qualified with the package.
    let get_user = &user_service.methods[0];
    assert_eq!(get_user.name, "GetUser");
    assert_eq!(get_user.input_type, ".test.GetUserRequest");
    assert_eq!(get_user.output_type, ".test.GetUserResponse");
    assert!(!get_user.client_streaming);
    assert!(!get_user.server_streaming);

    let list_users = &user_service.methods[1];
    assert_eq!(list_users.name, "ListUsers");
    assert!(!list_users.client_streaming);
    assert!(list_users.server_streaming);

    let update_user = &user_service.methods[2];
    assert_eq!(update_user.name, "UpdateUser");
    assert!(update_user.client_streaming);
    assert!(!update_user.server_streaming);
}
// Messages nested three levels deep keep their parent/child structure.
#[tokio::test]
async fn test_parse_nested_messages() {
    let content = r#"
syntax = "proto3";
package test;
message Outer {
string outer_field = 1;
message Inner {
string inner_field = 1;
message Deepest {
int32 deep_field = 1;
}
}
}
"#;
    let outcome = ParsedProto::parse("test.proto".to_string(), content).await;
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    assert_eq!(parsed.messages.len(), 1);

    let level_one = &parsed.messages[0];
    assert_eq!(level_one.name, "Outer");
    assert_eq!(level_one.nested_messages.len(), 1);

    let level_two = &level_one.nested_messages[0];
    assert_eq!(level_two.name, "Inner");
    assert_eq!(level_two.nested_messages.len(), 1);

    let level_three = &level_two.nested_messages[0];
    assert_eq!(level_three.name, "Deepest");
}
// An `extend` block is kept apart from real messages and does not alter the
// message it extends.
#[tokio::test]
async fn test_parse_extend() {
    let content = r#"
syntax = "proto2";
package test;
message Base {
optional string name = 1;
}
extend Base {
optional int32 extra_field = 100;
}
message Other {
optional string value = 1;
}
"#;
    let outcome = ParsedProto::parse("test.proto".to_string(), content).await;
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();

    assert_eq!(parsed.messages.len(), 2, "Should have exactly 2 messages (Base and Other), not the extend");
    assert_eq!(parsed.messages[0].name, "Base");
    assert_eq!(parsed.messages[1].name, "Other");

    assert_eq!(parsed.extends.len(), 1);
    let extend_block = &parsed.extends[0];
    assert_eq!(extend_block.name, "Base");
    assert_eq!(extend_block.full_name, "test.Base");
    assert_eq!(extend_block.fields.len(), 1);
    assert_eq!(extend_block.fields[0].name, "extra_field");

    // The extended message still has only its own field.
    let lookup = parsed.find_message_by_name("Base");
    assert!(lookup.is_some());
    let base = lookup.unwrap();
    assert_eq!(base.fields.len(), 1);
    assert_eq!(base.fields[0].name, "name");
}
// Fields declared in an `extend` block can be looked up by name.
#[tokio::test]
async fn test_extend_field_lookup() {
    let content = r#"
syntax = "proto2";
package tlvpickle;
message MethodOptions {
optional string name = 1;
}
extend MethodOptions {
optional int32 CmdID = 1000000;
optional string RpcRouteMethod = 1000015;
optional string Brief = 1000005;
}
"#;
    let outcome = ParsedProto::parse("skbuiltintype.proto".to_string(), content).await;
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();

    assert_eq!(parsed.messages.len(), 1);
    assert_eq!(parsed.messages[0].name, "MethodOptions");
    assert_eq!(parsed.extends.len(), 1);
    assert_eq!(parsed.extends[0].name, "MethodOptions");
    assert_eq!(parsed.extends[0].fields.len(), 3);

    let route = parsed.find_extend_field_by_name("RpcRouteMethod");
    assert!(route.is_some());
    let (owner, field) = route.unwrap();
    assert_eq!(owner.name, "MethodOptions");
    assert_eq!(field.name, "RpcRouteMethod");

    let cmd_id = parsed.find_extend_field_by_name("CmdID");
    assert!(cmd_id.is_some());

    let missing = parsed.find_extend_field_by_name("NonExistent");
    assert!(missing.is_none());
}
// `character` of a message must be the column of its name, not of the keyword.
#[tokio::test]
async fn test_message_character_points_to_name_not_keyword() {
    let content = "syntax = \"proto3\";\n\nmessage UserRequest {\n string user_id = 1;\n}\n";
    let result = ParsedProto::parse("test.proto".to_string(), content).await;
    assert!(result.is_ok());
    let proto = result.unwrap();
    let msg = proto.find_message_by_name("UserRequest").unwrap();

    // Work out what the reported column points at using a checked slice: an
    // unchecked range would panic for columns past the end of the line and
    // hide the real assertion failure.
    let declaration = "message UserRequest {";
    let start = msg.character as usize;
    let pointed_at = declaration
        .get(start..start + "UserRequest".len())
        .unwrap_or("<outside the declaration line>");

    assert_eq!(
        msg.character, 8,
        "BUG: message character is {} but should be 8 (pointing to name, not keyword). \
         Current value points to '{}' instead of the name 'UserRequest'.",
        msg.character,
        pointed_at
    );
}
// `character` of an enum must be the column of its name, not of the keyword.
#[tokio::test]
async fn test_enum_character_points_to_name_not_keyword() {
    let content = "syntax = \"proto3\";\n\nenum Status {\n UNKNOWN = 0;\n}\n";
    let outcome = ParsedProto::parse("test.proto".to_string(), content).await;
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    let status = parsed.find_enum_by_name("Status").unwrap();
    let reported = status.character;
    assert_eq!(
        reported, 5,
        "BUG: enum character is {} but should be 5 (pointing to name, not keyword)",
        reported
    );
}
// `character` of a service must be the column of its name, not of the keyword.
#[tokio::test]
async fn test_service_character_points_to_name_not_keyword() {
    let content = "syntax = \"proto3\";\n\nservice UserService {\n rpc Get(Req) returns (Resp);\n}\nmessage Req {}\nmessage Resp {}\n";
    let outcome = ParsedProto::parse("test.proto".to_string(), content).await;
    assert!(outcome.is_ok());
    let parsed = outcome.unwrap();
    let user_service = parsed.find_service_by_name("UserService").unwrap();
    let reported = user_service.character;
    assert_eq!(
        reported, 8,
        "BUG: service character is {} but should be 8 (pointing to name, not keyword)",
        reported
    );
}
}