package livekit
import (
"errors"
fmt "fmt"
"strconv"
"strings"
)
// allowedCharacters is a character whitelist for validating strings.
//
// ASCII characters (0x00–0x7F) are enabled one by one in the ascii table.
// Characters outside ASCII cannot be enabled individually: they are either
// all allowed (utf8 == true) or all rejected.
type allowedCharacters struct {
	ascii [128]bool // ascii[c] is true when ASCII character c is allowed
	utf8  bool      // when true, every non-ASCII character is allowed
}

// NewAllowedCharacters returns an empty set that allows no characters.
func NewAllowedCharacters() *allowedCharacters {
	return &allowedCharacters{}
}

// AddUTF8 allows every non-ASCII character. It always returns nil.
func (a *allowedCharacters) AddUTF8() error {
	a.utf8 = true
	return nil
}

// AddNumbers allows the ASCII digits '0'–'9'. It always returns nil.
func (a *allowedCharacters) AddNumbers() error {
	a.allowRange('0', '9')
	return nil
}

// AddLowercaseASCII allows 'a'–'z'. It always returns nil.
func (a *allowedCharacters) AddLowercaseASCII() error {
	a.allowRange('a', 'z')
	return nil
}

// AddUppercaseASCII allows 'A'–'Z'. It always returns nil.
func (a *allowedCharacters) AddUppercaseASCII() error {
	a.allowRange('A', 'Z')
	return nil
}

// AddPrintableLienarASCII allows printable ASCII (0x20–0x7E, space through
// '~') plus horizontal tab (0x09). The misspelled name is kept so existing
// callers continue to compile.
func (a *allowedCharacters) AddPrintableLienarASCII() {
	a.allowRange(0x20, 0x7E)
	a.ascii['\t'] = true
}

// Add allows every character in chars. All characters must be ASCII. If any
// is not, an error naming the first offending character is returned and the
// set is left unchanged. Use AddUTF8 to allow non-ASCII characters.
func (a *allowedCharacters) Add(chars string) error {
	return a.setASCII(chars, true)
}

// Remove disallows every character in chars. It follows the same ASCII-only,
// all-or-nothing rules as Add.
func (a *allowedCharacters) Remove(chars string) error {
	return a.setASCII(chars, false)
}

// Copy returns an independent copy of the set. Changing one does not affect
// the other.
func (a *allowedCharacters) Copy() *allowedCharacters {
	c := *a // the ascii array is copied by value
	return &c
}

// Validate returns nil if every character of target is allowed.
//
// It returns an error for the first character that is either a disallowed
// ASCII character or a non-ASCII character while UTF-8 is not enabled. When
// UTF-8 is enabled, non-ASCII characters are accepted only if they are
// validly encoded: byte sequences that are not UTF-8 are rejected.
func (a *allowedCharacters) Validate(target string) error {
	for i, char := range target {
		if int(char) < len(a.ascii) {
			if !a.ascii[char] {
				return fmt.Errorf("char %d not allowed", char)
			}
			continue
		}
		if !a.utf8 {
			return fmt.Errorf("char %d out of range, consider explicilty adding utf8 characters", char)
		}
		// range yields U+FFFD for every byte it cannot decode. A U+FFFD that
		// genuinely appears in target is stored as the bytes EF BF BD, so a
		// U+FFFD not backed by that encoding marks invalid UTF-8.
		if char == '\uFFFD' && !strings.HasPrefix(target[i:], "\uFFFD") {
			return fmt.Errorf("invalid utf8 encoding at byte %d", i)
		}
	}
	return nil
}

// allowRange allows every ASCII character from lo through hi inclusive.
func (a *allowedCharacters) allowRange(lo, hi rune) {
	for c := lo; c <= hi; c++ {
		a.ascii[c] = true
	}
}

// setASCII sets the allowed flag of every character in chars. The whole
// string is checked before anything is modified, so a failed call leaves the
// set unchanged.
func (a *allowedCharacters) setASCII(chars string, allowed bool) error {
	for _, char := range chars {
		if int(char) >= len(a.ascii) {
			return fmt.Errorf("char %d out of range, consider explicilty adding utf8 characters", char)
		}
	}
	for _, char := range chars {
		a.ascii[char] = allowed
	}
	return nil
}
// Character sets used by the header validators, populated once in init.
var (
	// tokenCharacters: ASCII digits, letters and -.!%*_+`'~ (header names).
	tokenCharacters *allowedCharacters
	// displayNameCharacters: tokenCharacters plus space and tab.
	displayNameCharacters *allowedCharacters
	// headerValuesCharacters: printable ASCII and tab, plus non-ASCII
	// characters (AddUTF8).
	headerValuesCharacters *allowedCharacters
)

func init() {
	// Every input below is a constant, so an error here is a programming
	// mistake (e.g. a non-ASCII character passed to Add). Fail loudly instead
	// of silently building an incomplete character set.
	must := func(err error) {
		if err != nil {
			panic(fmt.Sprintf("livekit: building SIP header character sets: %v", err))
		}
	}

	tokenCharacters = NewAllowedCharacters()
	must(tokenCharacters.AddNumbers())
	must(tokenCharacters.AddLowercaseASCII())
	must(tokenCharacters.AddUppercaseASCII())
	must(tokenCharacters.Add("-.!%*_+`'~"))

	displayNameCharacters = tokenCharacters.Copy()
	must(displayNameCharacters.Add(" \t"))

	headerValuesCharacters = NewAllowedCharacters()
	headerValuesCharacters.AddPrintableLienarASCII()
	must(headerValuesCharacters.AddUTF8())
}
// RequiredRequestHeaders holds the lower-case names of the headers treated
// as mandatory in a SIP request. The six entries are the minimum set listed
// in RFC 3261 §8.1.1.
var RequiredRequestHeaders = map[string]bool{
	"call-id":      true,
	"cseq":         true,
	"from":         true,
	"max-forwards": true,
	"to":           true,
	"via":          true,
}
// RequiredResponseHeaders holds the lower-case names of the headers treated
// as mandatory in a SIP response: the request set without max-forwards.
var RequiredResponseHeaders = map[string]bool{
	"call-id": true,
	"cseq":    true,
	"from":    true,
	"to":      true,
	"via":     true,
}
// FrobiddenSipHeaderNames lists lower-case SIP header names that
// ValidateHeaderName rejects when restrictNames is true (the lookup lowers the
// candidate name first, so matching is case-insensitive). The misspelled
// identifier is kept for compatibility with existing callers.
//
// NOTE(review): the long/compact pairs look incomplete. "t", "v" and "k" are
// listed without "to", "via" and "supported", and "content-type",
// "content-encoding", "from" and "call-id" are listed without their compact
// forms "c", "e", "f" and "i". Confirm whether this is intentional.
var FrobiddenSipHeaderNames = map[string]bool{
	// Full header names.
	"accept":           true,
	"accept-encoding":  true,
	"accept-language":  true,
	"allow":            true,
	"allow-events":     true,
	"call-id":          true,
	"contact":          true,
	"content-encoding": true,
	"content-length":   true,
	"content-type":     true,
	"cseq":             true,
	"event":            true,
	"expires":          true,
	"from":             true,
	"max-forwards":     true,
	"record-route":     true,
	"refer-to":         true,
	"reply-to":         true,

	// Compact (single-letter) header forms.
	"k": true, // Supported
	"l": true, // Content-Length
	"m": true, // Contact
	"o": true, // Event
	"r": true, // Refer-To
	"t": true, // To
	"u": true, // Allow-Events
	"v": true, // Via
}
// nameAddrHeaders lists lower-case header names whose values ValidateHeaderValue
// would check with validateNameAddrHeader (`[display-name] <uri>` or a bare
// URI); that check is currently switched off there.
var nameAddrHeaders = map[string]bool{
	"contact":             true,
	"from":                true,
	"p-asserted-identity": true,
	"record-route":        true,
	"reply-to":            true,
	"route":               true,
	"to":                  true,
}
// ValidateHeaderName reports whether name may be used as a SIP header name.
//
// The name must be non-empty, at most 255 bytes long, and made only of the
// characters in tokenCharacters. When restrictNames is true, names present in
// FrobiddenSipHeaderNames (compared after lower-casing) are rejected as well.
func ValidateHeaderName(name string, restrictNames bool) error {
	switch {
	case name == "":
		return errors.New("header name cannot be empty")
	case len(name) > 255:
		return errors.New("header name too long (max 255 characters)")
	}

	if err := tokenCharacters.Validate(name); err != nil {
		return fmt.Errorf("header name %s contains invalid characters: %w", name, err)
	}

	// A missing key yields false, so a single lookup covers "present and true".
	if restrictNames && FrobiddenSipHeaderNames[strings.ToLower(name)] {
		return fmt.Errorf("header name %s not supported", name)
	}
	return nil
}
// ValidateHeaderValue reports whether value may be used as the value of the
// SIP header called name. name appears in error messages and selects the
// (currently disabled) name-addr check.
//
// An empty value is always accepted. Otherwise the value must be at most 1024
// bytes long and pass headerValuesCharacters.
func ValidateHeaderValue(name, value string) error {
	// Structural name-addr validation is switched off.
	// NOTE(review): validateNameAddrHeader rejects values containing more than
	// one '<' or '>', e.g. a comma-separated list like "<sip:a>, <sip:b>",
	// which may be why it is disabled. Confirm before enabling.
	const checkNameAddr = false

	if len(value) == 0 {
		return nil
	}
	if len(value) > 1024 {
		return fmt.Errorf("header %s: value too long (max 1024 characters)", name)
	}
	if err := headerValuesCharacters.Validate(value); err != nil {
		return fmt.Errorf("header %s: value: %w", name, err)
	}

	if checkNameAddr {
		if _, isNameAddr := nameAddrHeaders[strings.ToLower(name)]; isNameAddr {
			if err := validateNameAddrHeader(value); err != nil {
				return fmt.Errorf("header %s: value: %w", name, err)
			}
		}
	}
	return nil
}
// findAngleBrackets locates the single '<' ... '>' pair in value.
//
// It returns the byte offsets of '<' and '>', or (-1, -1, nil) when neither
// bracket is present. It fails if a bracket appears more than once (reported
// at the first repeat encountered), if only one of the two is present, or if
// '>' comes before '<'.
func findAngleBrackets(value string) (int, int, error) {
	lt, gt := -1, -1
	// '<' and '>' are ASCII and never occur inside a multi-byte UTF-8
	// sequence, so scanning bytes finds the same offsets as scanning runes.
	for i := 0; i < len(value); i++ {
		c := value[i]
		if c == '<' {
			if lt >= 0 {
				return -1, -1, errors.New("multiple opening brackets")
			}
			lt = i
		} else if c == '>' {
			if gt >= 0 {
				return -1, -1, errors.New("multiple closing brackets")
			}
			gt = i
		}
	}

	hasLT, hasGT := lt >= 0, gt >= 0
	switch {
	case hasLT != hasGT:
		return -1, -1, errors.New("mismatched angle brackets")
	case lt > gt:
		return -1, -1, errors.New("malformed angle brackets")
	}
	return lt, gt, nil
}
// validateNameAddrHeader checks a header value written either as a name-addr
// (`[display-name] <uri>`) or as a bare URI.
//
// With angle brackets, the trimmed text before '<' must be a valid display
// name and the text between '<' and '>' a valid URI. Anything after '>' is not
// inspected. Without brackets, the whole value is the URI and must not contain
// ';', ',', '?' or a space.
func validateNameAddrHeader(value string) error {
	lt, gt, err := findAngleBrackets(value)
	if err != nil {
		return err
	}

	if lt < 0 && gt < 0 {
		// Bare URI: these characters would be ambiguous with header
		// parameters or list separators outside of angle brackets.
		if strings.ContainsAny(value, ";,? ") {
			return errors.New("bare URI with special characters")
		}
		return validateURI(value)
	}

	if err := validateDisplayName(strings.TrimSpace(value[:lt])); err != nil {
		return err
	}
	return validateURI(value[lt+1 : gt])
}
// validateDisplayName checks the display-name part of a name-addr value (the
// text before '<', with surrounding whitespace already trimmed by the caller).
//
// An empty display name is accepted. A display name that begins and ends
// with '"' must be an RFC 3261 quoted-string:
//   - between the quotes it may contain space, tab, printable ASCII other
//     than '"' and '\', and non-ASCII bytes;
//   - each '\' must be followed by an ASCII byte other than CR or LF
//     (quoted-pair).
// Any other display name may contain only displayNameCharacters.
//
// The quoted form is checked by hand rather than with strconv.Unquote because
// Go string-literal escapes differ from SIP quoted-pairs. For example,
// Unquote rejects the valid SIP text `"O\'Brien"` and `"\q"`.
//
// NOTE(review): non-ASCII bytes inside quotes are not checked for valid UTF-8
// here; presumably the caller's character validation covers encoding. Confirm.
func validateDisplayName(displayName string) error {
	if displayName == "" {
		return nil
	}
	if !strings.HasPrefix(displayName, `"`) || !strings.HasSuffix(displayName, `"`) {
		if err := displayNameCharacters.Validate(displayName); err != nil {
			return fmt.Errorf("display name: %w", err)
		}
		return nil
	}

	// A lone `"` passes both checks above but is not a quoted-string.
	if len(displayName) < 2 {
		return errors.New("display name: unterminated quoted string")
	}
	last := len(displayName) - 1 // index of the closing '"'
	for i := 1; i < last; i++ {
		switch c := displayName[i]; {
		case c == '\\':
			i++
			if i == last {
				// The closing '"' is escaped, so the string never ends.
				return errors.New("display name: unterminated quoted string")
			}
			if esc := displayName[i]; esc > 0x7F || esc == '\r' || esc == '\n' {
				return errors.New("display name: invalid escape at byte " + strconv.Itoa(i))
			}
		case c == '"':
			return errors.New("display name: unescaped quote at byte " + strconv.Itoa(i))
		case c == ' ' || c == '\t' || (c >= 0x21 && c <= 0x7E) || c >= 0x80:
			// qdtext; '"' and '\' are handled by the cases above.
		default:
			return errors.New("display name: char " + strconv.Itoa(int(c)) + " not allowed in quoted string")
		}
	}
	return nil
}
// validateURI performs a light sanity check on a URI taken from a name-addr
// header value.
//
// The URI must start with a "sip", "sips" or "tel" scheme followed by ':'.
// The scheme is compared case-insensitively, as URI schemes are
// case-insensitive (RFC 3986 §3.1). Something must follow the ':', and the
// URI must not contain spaces. Nothing else about the URI is checked.
func validateURI(uri string) error {
	// Without a ':' the whole string is treated as the scheme, so e.g. "sip"
	// reaches the separator check below rather than being accepted.
	scheme, rest, hasColon := uri, "", false
	if i := strings.IndexByte(uri, ':'); i >= 0 {
		scheme, rest, hasColon = uri[:i], uri[i+1:], true
	}
	if !strings.EqualFold(scheme, "sip") && !strings.EqualFold(scheme, "sips") && !strings.EqualFold(scheme, "tel") {
		return errors.New("uri: scheme not one of sip, sips, or tel")
	}
	if !hasColon {
		return errors.New("uri: missing ':' after scheme")
	}
	if rest == "" {
		return errors.New("uri: empty after scheme")
	}
	if strings.Contains(uri, " ") {
		return errors.New("uri: contains spaces")
	}
	return nil
}